diff --git a/backend/couchdb/auth.go b/backend/couchdb/auth.go index 0b44af8bd..17f2e969e 100644 --- a/backend/couchdb/auth.go +++ b/backend/couchdb/auth.go @@ -14,8 +14,8 @@ func (l *AuthTransport) RoundTrip(req *http.Request) (*http.Response, error) { return l.transport.RoundTrip(req) } -func (rt *AuthTransport) Transport() http.RoundTripper { - return rt.transport +func (l *AuthTransport) Transport() http.RoundTripper { + return l.transport } func (l *AuthTransport) SetTransport(rt http.RoundTripper) { diff --git a/backend/couchdb/db.go b/backend/couchdb/db.go index e66386c52..c1e5a8f54 100644 --- a/backend/couchdb/db.go +++ b/backend/couchdb/db.go @@ -64,6 +64,7 @@ func clientAndDB(ctx context.Context, dbName string, cfg *Config) (*kivik.Client if db.Err() != nil { return nil, nil, db.Err() } + return client, db, err } @@ -74,6 +75,7 @@ func Client(cfg *Config) (*kivik.Client, error) { if err != nil { return nil, err } + rts := []transport.ChainableRoundTripper{ &AuthTransport{ Username: cfg.User, @@ -84,10 +86,11 @@ func Client(cfg *Config) (*kivik.Client, error) { if !cfg.DisableRequestLogging { rts = append(rts, &transport.LoggingRoundTripper{}) } + chain := transport.Chain(rts...) tr := couchdb.SetTransport(chain) - err = client.Authenticate(ctx, tr) - if err != nil { + + if err := client.Authenticate(ctx, tr); err != nil { return nil, err } @@ -96,9 +99,10 @@ func Client(cfg *Config) (*kivik.Client, error) { func ParseConfig() (*Config, error) { var cfg Config - err := env.Parse(&cfg) - if err != nil { + + if err := env.Parse(&cfg); err != nil { return nil, err } + return &cfg, nil } diff --git a/backend/couchdb/health_check.go b/backend/couchdb/health_check.go index c60e3fc52..dec9a6eca 100644 --- a/backend/couchdb/health_check.go +++ b/backend/couchdb/health_check.go @@ -7,7 +7,9 @@ import ( "time" kivik "github.com/go-kivik/kivik/v3" + "github.com/pace/bricks/maintenance/health/servicehealthcheck" + "github.com/pace/bricks/maintenance/log" ) // HealthCheck checks the state of the object storage client. It must not be changed @@ -28,7 +30,7 @@ var ( // HealthCheck checks if the object storage client is healthy. If the last result is outdated, // object storage is checked for upload and download, -// otherwise returns the old result +// otherwise returns the old result. func (h *HealthCheck) HealthCheck(ctx context.Context) servicehealthcheck.HealthCheckResult { if time.Since(h.state.LastChecked()) <= h.Config.HealthCheckResultTTL { // the last health check is not outdated, an can be reused. @@ -38,14 +40,16 @@ func (h *HealthCheck) HealthCheck(ctx context.Context) servicehealthcheck.Health checkTime := time.Now() var doc Doc + var err error + var row *kivik.Row check: // check if context was canceled select { case <-ctx.Done(): - h.state.SetErrorState(fmt.Errorf("failed: %v", ctx.Err())) + h.state.SetErrorState(fmt.Errorf("failed: %w", ctx.Err())) return h.state.GetState() default: } @@ -55,16 +59,22 @@ check: if kivik.StatusCode(row.Err) == http.StatusNotFound { goto put } - h.state.SetErrorState(fmt.Errorf("failed to get: %#v", row.Err)) + + h.state.SetErrorState(fmt.Errorf("failed to get: %w", row.Err)) + return h.state.GetState() } - defer row.Body.Close() + + defer func() { + if err := row.Body.Close(); err != nil { + log.Ctx(ctx).Debug().Err(err).Msg("Failed closing body") + } + }() // check if document exists if row.Rev != "" { - err = row.ScanDoc(&doc) - if err != nil { - h.state.SetErrorState(fmt.Errorf("failed to get: %v", row.Err)) + if err := row.ScanDoc(&doc); err != nil { + h.state.SetErrorState(fmt.Errorf("failed to get: %w", row.Err)) return h.state.GetState() } @@ -77,23 +87,27 @@ check: put: // update document doc.ID = h.Config.HealthCheckKey + doc.Time = time.Now().Format(healthCheckTimeFormat) + _, err = h.DB.Put(ctx, h.Config.HealthCheckKey, doc) if err != nil { // not yet created, try to create if h.Config.DatabaseAutoCreate && kivik.StatusCode(err) == http.StatusNotFound { - err := h.Client.CreateDB(ctx, h.Name) - if err != nil { - h.state.SetErrorState(fmt.Errorf("failed to put object: %v", err)) + if err := h.Client.CreateDB(ctx, h.Name); err != nil { + h.state.SetErrorState(fmt.Errorf("failed to put object: %w", err)) return h.state.GetState() } + goto put } if kivik.StatusCode(err) == http.StatusConflict { goto check } - h.state.SetErrorState(fmt.Errorf("failed to put object: %v", err)) + + h.state.SetErrorState(fmt.Errorf("failed to put object: %w", err)) + return h.state.GetState() } @@ -103,6 +117,7 @@ put: healthy: // If uploading and downloading worked set the Health Check to healthy h.state.SetHealthy() + return h.state.GetState() } @@ -116,7 +131,7 @@ type Doc struct { // time span concurrent request to the objstore may break the assumption // that the value is the same, but in this case it would be acceptable. // Assumption all instances are created equal and one providing evidence -// of a good write would be sufficient. See #244 +// of a good write would be sufficient. See #244. func wasConcurrentHealthCheck(checkTime time.Time, observedValue string) bool { t, err := time.Parse(healthCheckTimeFormat, observedValue) if err == nil { @@ -124,7 +139,7 @@ func wasConcurrentHealthCheck(checkTime time.Time, observedValue string) bool { allowedEnd := checkTime.Add(healthCheckConcurrentSpan) // timestamp we got from the document is in allowed range - // concider it healthy + // consider it healthy return t.After(allowedStart) && t.Before(allowedEnd) } diff --git a/backend/k8sapi/client.go b/backend/k8sapi/client.go index 78ffda131..3a3d67412 100644 --- a/backend/k8sapi/client.go +++ b/backend/k8sapi/client.go @@ -15,24 +15,25 @@ import ( "strings" "github.com/caarlos0/env/v10" + "github.com/pace/bricks/http/transport" "github.com/pace/bricks/maintenance/log" ) -// Client minimal client for the kubernetes API +// Client minimal client for the kubernetes API. type Client struct { Podname string Namespace string CACert []byte Token string cfg Config - HttpClient *http.Client + HTTPClient *http.Client } -// NewClient create new api client +// NewClient create new api client. func NewClient() (*Client, error) { cl := Client{ - HttpClient: &http.Client{}, + HTTPClient: &http.Client{}, } // lookup hostname (for pod update) @@ -40,51 +41,58 @@ func NewClient() (*Client, error) { if err != nil { return nil, err } + cl.Podname = hostname // parse environment including secrets mounted by kubernetes - err = env.Parse(&cl.cfg) - if err != nil { + if err := env.Parse(&cl.cfg); err != nil { return nil, err } caData, err := os.ReadFile(cl.cfg.CACertFile) if err != nil { - return nil, fmt.Errorf("failed to read %q: %v", cl.cfg.CACertFile, err) + return nil, fmt.Errorf("failed to read %q: %w", cl.cfg.CACertFile, err) } + cl.CACert = []byte(strings.TrimSpace(string(caData))) namespaceData, err := os.ReadFile(cl.cfg.NamespaceFile) if err != nil { - return nil, fmt.Errorf("failed to read %q: %v", cl.cfg.NamespaceFile, err) + return nil, fmt.Errorf("failed to read %q: %w", cl.cfg.NamespaceFile, err) } + cl.Namespace = strings.TrimSpace(string(namespaceData)) tokenData, err := os.ReadFile(cl.cfg.TokenFile) if err != nil { - return nil, fmt.Errorf("failed to read %q: %v", cl.cfg.CACertFile, err) + return nil, fmt.Errorf("failed to read %q: %w", cl.cfg.CACertFile, err) } + cl.Token = strings.TrimSpace(string(tokenData)) // add kubernetes api server cert chain := transport.NewDefaultTransportChain() pool := x509.NewCertPool() + ok := pool.AppendCertsFromPEM(cl.CACert) if !ok { return nil, fmt.Errorf("failed to load kubernetes ca cert") } + chain.Final(&http.Transport{ TLSClientConfig: &tls.Config{ - RootCAs: pool, + RootCAs: pool, + MinVersion: tls.VersionTLS12, }, }) - cl.HttpClient.Transport = chain + + cl.HTTPClient.Transport = chain return &cl, nil } // SimpleRequest send a simple http request to kubernetes with the passed -// method, url and requestObj, decoding the result into responseObj +// method, url and requestObj, decoding the result into responseObj. func (c *Client) SimpleRequest(ctx context.Context, method, url string, requestObj, responseObj interface{}) error { data, err := json.Marshal(requestObj) if err != nil { @@ -99,16 +107,22 @@ func (c *Client) SimpleRequest(ctx context.Context, method, url string, requestO req.Header.Set("Content-Type", "application/json-patch+json") req.Header.Set("Authorization", "Bearer "+c.Token) - resp, err := c.HttpClient.Do(req) + resp, err := c.HTTPClient.Do(req) if err != nil { log.Ctx(ctx).Debug().Err(err).Msg("failed to do api request") return err } - defer resp.Body.Close() + + defer func() { + if err := resp.Body.Close(); err != nil { + log.Ctx(ctx).Debug().Err(err).Msg("failed to close response body") + } + }() if resp.StatusCode > 299 { - body, _ := io.ReadAll(resp.Body) // nolint: errcheck + body, _ := io.ReadAll(resp.Body) log.Ctx(ctx).Debug().Msgf("failed to do api request, due to: %s", string(body)) + return fmt.Errorf("k8s request failed with %s", resp.Status) } diff --git a/backend/k8sapi/config.go b/backend/k8sapi/config.go index 901395c74..dabba28af 100644 --- a/backend/k8sapi/config.go +++ b/backend/k8sapi/config.go @@ -3,7 +3,7 @@ package k8sapi // Config gathers the required kubernetes system configuration to use the -// kubernetes API +// kubernetes API. type Config struct { Host string `env:"KUBERNETES_SERVICE_HOST" envDefault:"localhost"` Port int `env:"KUBERNETES_PORT_443_TCP_PORT" envDefault:"433"` diff --git a/backend/k8sapi/pod.go b/backend/k8sapi/pod.go index 2d0a8ebac..03114adee 100644 --- a/backend/k8sapi/pod.go +++ b/backend/k8sapi/pod.go @@ -9,13 +9,13 @@ import ( ) // SetCurrentPodLabel set the label for the current pod in the current -// namespace (requires patch on pods resource) +// namespace (requires patch on pods resource). func (c *Client) SetCurrentPodLabel(ctx context.Context, label, value string) error { return c.SetPodLabel(ctx, c.Namespace, c.Podname, label, value) } // SetPodLabel sets the label and value for the pod of the given namespace -// (requires patch on pods resource in the given namespace) +// (requires patch on pods resource in the given namespace). func (c *Client) SetPodLabel(ctx context.Context, namespace, podname, label, value string) error { pr := []struct { Op string `json:"op"` @@ -30,6 +30,7 @@ func (c *Client) SetPodLabel(ctx context.Context, namespace, podname, label, val } url := fmt.Sprintf("https://%s:%d/api/v1/namespaces/%s/pods/%s", c.cfg.Host, c.cfg.Port, namespace, podname) + var resp interface{} return c.SimpleRequest(ctx, http.MethodPatch, url, &pr, &resp) diff --git a/backend/objstore/health_objstore.go b/backend/objstore/health_objstore.go index fd268d614..9dd6ffea6 100644 --- a/backend/objstore/health_objstore.go +++ b/backend/objstore/health_objstore.go @@ -8,6 +8,7 @@ import ( "time" "github.com/minio/minio-go/v7" + "github.com/pace/bricks/maintenance/errors" "github.com/pace/bricks/maintenance/health/servicehealthcheck" "github.com/pace/bricks/maintenance/log" @@ -27,7 +28,7 @@ var ( // HealthCheck checks if the object storage client is healthy. If the last result is outdated, // object storage is checked for upload and download, -// otherwise returns the old result +// otherwise returns the old result. func (h *HealthCheck) HealthCheck(ctx context.Context) servicehealthcheck.HealthCheckResult { if time.Since(h.state.LastChecked()) <= cfg.HealthCheckResultTTL { // the last health check is not outdated, an can be reused. @@ -49,7 +50,7 @@ func (h *HealthCheck) HealthCheck(ctx context.Context) servicehealthcheck.Health }, ) if err != nil { - h.state.SetErrorState(fmt.Errorf("failed to put object: %v", err)) + h.state.SetErrorState(fmt.Errorf("failed to put object: %w", err)) return h.state.GetState() } @@ -58,6 +59,7 @@ func (h *HealthCheck) HealthCheck(ctx context.Context) servicehealthcheck.Health defer func() { go func() { defer errors.HandleWithCtx(ctx, "HealthCheck remove s3 object version") + ctx := log.WithContext(context.Background()) err = h.Client.RemoveObject( @@ -88,15 +90,20 @@ func (h *HealthCheck) HealthCheck(ctx context.Context) servicehealthcheck.Health }, ) if err != nil { - h.state.SetErrorState(fmt.Errorf("failed to get object: %v", err)) + h.state.SetErrorState(fmt.Errorf("failed to get object: %w", err)) return h.state.GetState() } - defer obj.Close() + + defer func() { + if err := obj.Close(); err != nil { + log.Ctx(ctx).Debug().Err(err).Msg("Failed closing object") + } + }() // Assert expectations buf, err := io.ReadAll(obj) if err != nil { - h.state.SetErrorState(fmt.Errorf("failed to compare object: %v", err)) + h.state.SetErrorState(fmt.Errorf("failed to compare object: %w", err)) return h.state.GetState() } @@ -106,12 +113,14 @@ func (h *HealthCheck) HealthCheck(ctx context.Context) servicehealthcheck.Health } h.state.SetErrorState(fmt.Errorf("unexpected content: %q <-> %q", string(buf), string(expContent))) + return h.state.GetState() } healthy: // If uploading and downloading worked set the Health Check to healthy h.state.SetHealthy() + return h.state.GetState() } @@ -119,7 +128,7 @@ healthy: // time span concurrent request to the objstore may break the assumption // that the value is the same, but in this case it would be acceptable. // Assumption all instances are created equal and one providing evidence -// of a good write would be sufficient. See #244 +// of a good write would be sufficient. See #244. func wasConcurrentHealthCheck(checkTime time.Time, observedValue string) bool { t, err := time.Parse(healthCheckTimeFormat, observedValue) if err == nil { @@ -127,7 +136,7 @@ func wasConcurrentHealthCheck(checkTime time.Time, observedValue string) bool { allowedEnd := checkTime.Add(healthCheckConcurrentSpan) // timestamp we got from the document is in allowed range - // concider it healthy + // consider it healthy return t.After(allowedStart) && t.Before(allowedEnd) } diff --git a/backend/objstore/health_objstore_test.go b/backend/objstore/health_objstore_test.go index 9c0eda18c..c0888f57a 100644 --- a/backend/objstore/health_objstore_test.go +++ b/backend/objstore/health_objstore_test.go @@ -8,30 +8,39 @@ import ( "testing" "time" + "github.com/stretchr/testify/assert" + http2 "github.com/pace/bricks/http" "github.com/pace/bricks/maintenance/log" - "github.com/stretchr/testify/assert" ) func setup() *http.Response { r := http2.Router() rec := httptest.NewRecorder() - req := httptest.NewRequest("GET", "/health/check", nil) + req := httptest.NewRequest(http.MethodGet, "/health/check", nil) r.ServeHTTP(rec, req) - resp := rec.Result() - defer resp.Body.Close() - return resp + + return rec.Result() } -// TestIntegrationHealthCheck tests if object storage health check ist working like expected +// TestIntegrationHealthCheck tests if object storage health check ist working like expected. func TestIntegrationHealthCheck(t *testing.T) { if testing.Short() { t.SkipNow() } + RegisterHealthchecks() time.Sleep(1 * time.Second) // by the magic of asynchronous code, I here-by present a magic wait + resp := setup() - if resp.StatusCode != 200 { + + defer func() { + if err := resp.Body.Close(); err != nil { + log.Println(err) + } + }() + + if resp.StatusCode != http.StatusOK { t.Errorf("Expected /health/check to respond with 200, got: %d", resp.StatusCode) } @@ -39,6 +48,7 @@ func TestIntegrationHealthCheck(t *testing.T) { if err != nil { log.Fatal(err) } + if !strings.Contains(string(data), "objstore OK") { t.Errorf("Expected /health/check to return OK, got: %s", string(data)) } @@ -46,6 +56,7 @@ func TestIntegrationHealthCheck(t *testing.T) { func TestConcurrentHealth(t *testing.T) { ct := time.Date(2020, 12, 16, 15, 30, 46, 0, time.UTC) + tests := []struct { name string checkTime time.Time diff --git a/backend/objstore/objstore.go b/backend/objstore/objstore.go index 36a1b596f..f0b068e0d 100644 --- a/backend/objstore/objstore.go +++ b/backend/objstore/objstore.go @@ -33,15 +33,17 @@ func RegisterHealthchecks() { registerHealthchecks() } -// deprecated consider using DefaultClientFromEnv +// Client returns the default client. +// Deprecated: consider using DefaultClientFromEnv. func Client() (*minio.Client, error) { return DefaultClientFromEnv() } -// Client with environment based configuration. Registers healthchecks automatically. If yo do not want to use healthchecks +// DefaultClientFromEnv with environment based configuration. Registers healthchecks automatically. If yo do not want to use healthchecks // consider calling CustomClient. func DefaultClientFromEnv() (*minio.Client, error) { registerHealthchecks() + return CustomClient(cfg.Endpoint, &minio.Options{ Secure: cfg.UseSSL, Region: cfg.Region, @@ -50,17 +52,20 @@ func DefaultClientFromEnv() (*minio.Client, error) { }) } -// CustomClient with customized client +// CustomClient with customized client. func CustomClient(endpoint string, opts *minio.Options) (*minio.Client, error) { opts.Transport = newCustomTransport(endpoint) + client, err := minio.New(endpoint, opts) if err != nil { return nil, err } + log.Logger().Info().Str("endpoint", endpoint). Str("region", opts.Region). Bool("ssl", opts.Secure). Msg("S3 connection created") + return client, nil } @@ -78,8 +83,7 @@ var register = &sync.Once{} func registerHealthchecks() { register.Do(func() { // parse log config - err := env.Parse(&cfg) - if err != nil { + if err := env.Parse(&cfg); err != nil { log.Fatalf("Failed to parse object storage environment: %v", err) } @@ -94,6 +98,7 @@ func registerHealthchecks() { if err != nil { log.Warnf("Failed to create check for bucket: %v", err) } + if !ok { err := client.MakeBucket(ctx, cfg.HealthCheckBucketName, minio.MakeBucketOptions{ Region: cfg.Region, @@ -102,6 +107,7 @@ func registerHealthchecks() { log.Warnf("Failed to create bucket: %v", err) } } + servicehealthcheck.RegisterHealthCheck("objstore", &HealthCheck{ Client: client, }) diff --git a/backend/postgres/errors.go b/backend/postgres/errors.go index b72d5ee20..db5ca63e7 100644 --- a/backend/postgres/errors.go +++ b/backend/postgres/errors.go @@ -19,14 +19,14 @@ func IsErrConnectionFailed(err error) bool { } // go-pg has this check internally for network errors - _, ok := err.(net.Error) - if ok { + var netErr net.Error + if errors.As(err, &netErr) { return true } // go-pg has similar check for integrity violation issues, here we check network issues - pgErr, ok := err.(pg.Error) - if ok { + var pgErr pg.Error + if errors.As(err, &pgErr) { code := pgErr.Field('C') // We check on error codes of Class 08 — Connection Exception. // https://www.postgresql.org/docs/10/errcodes-appendix.html @@ -34,5 +34,6 @@ func IsErrConnectionFailed(err error) bool { return true } } + return false } diff --git a/backend/postgres/health_postgres.go b/backend/postgres/health_postgres.go index 5b5dfe866..f6e1c7f58 100644 --- a/backend/postgres/health_postgres.go +++ b/backend/postgres/health_postgres.go @@ -7,6 +7,7 @@ import ( "time" "github.com/go-pg/pg/orm" + "github.com/pace/bricks/maintenance/health/servicehealthcheck" ) @@ -21,7 +22,7 @@ type postgresQueryExecutor interface { Exec(ctx context.Context, query interface{}, params ...interface{}) (res orm.Result, err error) } -// Init initializes the test table +// Init initializes the test table. func (h *HealthCheck) Init(ctx context.Context) error { _, errWrite := h.Pool.Exec(ctx, `CREATE TABLE IF NOT EXISTS `+cfg.HealthCheckTableName+`(ok boolean);`) return errWrite @@ -55,6 +56,7 @@ func (h *HealthCheck) HealthCheck(ctx context.Context) servicehealthcheck.Health } // If no error occurred set the State of this Health Check to healthy h.state.SetHealthy() + return h.state.GetState() } diff --git a/backend/postgres/health_postgres_test.go b/backend/postgres/health_postgres_test.go index 2bb978fd6..517402741 100644 --- a/backend/postgres/health_postgres_test.go +++ b/backend/postgres/health_postgres_test.go @@ -12,20 +12,28 @@ import ( "time" "github.com/go-pg/pg/orm" + "github.com/stretchr/testify/require" + http2 "github.com/pace/bricks/http" "github.com/pace/bricks/maintenance/errors" "github.com/pace/bricks/maintenance/health/servicehealthcheck" "github.com/pace/bricks/maintenance/log" - "github.com/stretchr/testify/require" ) func setup() *http.Response { r := http2.Router() rec := httptest.NewRecorder() - req := httptest.NewRequest("GET", "/health/check", nil) + req := httptest.NewRequest(http.MethodGet, "/health/check", nil) r.ServeHTTP(rec, req) + resp := rec.Result() - defer resp.Body.Close() + + defer func() { + if err := resp.Body.Close(); err != nil { + log.Println(err) + } + }() + return resp } @@ -33,9 +41,18 @@ func TestIntegrationHealthCheck(t *testing.T) { if testing.Short() { t.SkipNow() } + time.Sleep(1 * time.Second) // by the magic of asynchronous code, I here-by present a magic wait + resp := setup() - if resp.StatusCode != 200 { + + defer func() { + if err := resp.Body.Close(); err != nil { + log.Println(err) + } + }() + + if resp.StatusCode != http.StatusOK { t.Errorf("Expected /health/check to respond with 200, got: %d", resp.StatusCode) } @@ -43,6 +60,7 @@ func TestIntegrationHealthCheck(t *testing.T) { if err != nil { log.Fatal(err) } + if !strings.Contains(string(data[:]), "postgresdefault OK") { t.Errorf("Expected /health/check to return OK, got: %q", string(data[:])) } @@ -68,6 +86,7 @@ func TestHealthCheckCaching(t *testing.T) { // get the error for the first time require.Equal(t, servicehealthcheck.Err, res.State) require.Equal(t, "TestHealthCheckCaching", res.Msg) + res = h.HealthCheck(ctx) pool.err = nil // getting the cached error diff --git a/backend/postgres/metrics.go b/backend/postgres/metrics.go index 69766aaf3..555b6d93e 100644 --- a/backend/postgres/metrics.go +++ b/backend/postgres/metrics.go @@ -75,6 +75,7 @@ func NewConnectionPoolMetrics() *ConnectionPoolMetrics { []string{"database", "pool"}, ), } + return &m } @@ -121,7 +122,9 @@ func (m *ConnectionPoolMetrics) ObserveRegularly(ctx context.Context, db *pg.DB, // cleaning up the related resources. go func() { ticker := time.NewTicker(time.Minute) + defer close(trigger) + for { select { case <-ticker.C: @@ -143,7 +146,7 @@ func (m *ConnectionPoolMetrics) ObserveRegularly(ctx context.Context, db *pg.DB, } // ObserveWhenTriggered starts observing the given postgres pool. The pool name -// behaves as decribed for the ObserveRegularly method. The metrics are observed +// behaves as described for the ObserveRegularly method. The metrics are observed // for every emitted value from the trigger channel. The trigger channel allows // passing a response channel that will be closed once the metrics were // collected. It is also possible to pass nil. You should close the trigger @@ -152,13 +155,16 @@ func (m *ConnectionPoolMetrics) ObserveWhenTriggered(trigger <-chan chan<- struc // check that pool name is unique m.poolMetricsMx.Lock() defer m.poolMetricsMx.Unlock() + if _, ok := m.poolMetrics[poolName]; ok { return fmt.Errorf("invalid pool name: %q: %w", poolName, ErrNotUnique) } + m.poolMetrics[poolName] = struct{}{} // start goroutine go m.gatherConnectionPoolMetrics(trigger, db, poolName) + return nil } @@ -188,6 +194,7 @@ func (m *ConnectionPoolMetrics) gatherConnectionPoolMetrics(trigger <-chan chan< if done != nil { close(done) } + prevStats = *stats } } diff --git a/backend/postgres/metrics_test.go b/backend/postgres/metrics_test.go index 98ecd8ffa..d9cc27310 100644 --- a/backend/postgres/metrics_test.go +++ b/backend/postgres/metrics_test.go @@ -5,6 +5,7 @@ package postgres_test import ( "context" "errors" + "net/http" "net/http/httptest" "testing" "time" @@ -25,6 +26,7 @@ func ExampleConnectionPoolMetrics() { if err := metrics.ObserveRegularly(context.Background(), myDB, "my_db"); err != nil { panic(err) } + prometheus.MustRegister(metrics) } @@ -36,6 +38,7 @@ func TestIntegrationConnectionPoolMetrics(t *testing.T) { metricsRegistry := prometheus.NewRegistry() metrics := NewConnectionPoolMetrics() metricsRegistry.MustRegister(metrics) + db := ConnectionPool() trigger := make(chan chan<- struct{}) err := metrics.ObserveWhenTriggered(trigger, db, "test") @@ -44,6 +47,7 @@ func TestIntegrationConnectionPoolMetrics(t *testing.T) { if _, err := db.Exec(`SELECT 1;`); err != nil { t.Fatalf("could not query postgres database: %s", err) } + whenDone := make(chan struct{}) select { case trigger <- whenDone: @@ -58,7 +62,8 @@ func TestIntegrationConnectionPoolMetrics(t *testing.T) { // query metrics resp := httptest.NewRecorder() handler := promhttp.HandlerFor(metricsRegistry, promhttp.HandlerOpts{}) - handler.ServeHTTP(resp, httptest.NewRequest("GET", "/metrics", nil)) + handler.ServeHTTP(resp, httptest.NewRequest(http.MethodGet, "/metrics", nil)) + body := resp.Body.String() assert.Regexp(t, `pace_postgres_connection_pool_hits.*?\Wpool="test"\W`, body) assert.Regexp(t, `pace_postgres_connection_pool_misses.*?\Wpool="test"\W`, body) diff --git a/backend/postgres/options.go b/backend/postgres/options.go index cca80268b..c3965c0dc 100644 --- a/backend/postgres/options.go +++ b/backend/postgres/options.go @@ -15,35 +15,35 @@ func WithQueryLogging(logRead, logWrite bool) ConfigOption { } } -// WithPort - customize the db port +// WithPort - customize the db port. func WithPort(port int) ConfigOption { return func(cfg *Config) { cfg.Port = port } } -// WithHost - customise the db host +// WithHost - customise the db host. func WithHost(host string) ConfigOption { return func(cfg *Config) { cfg.Host = host } } -// WithPassword - customise the db password +// WithPassword - customise the db password. func WithPassword(password string) ConfigOption { return func(cfg *Config) { cfg.Password = password } } -// WithUser - customise the db user +// WithUser - customise the db user. func WithUser(user string) ConfigOption { return func(cfg *Config) { cfg.User = user } } -// WithDatabase - customise the db name +// WithDatabase - customise the db name. func WithDatabase(database string) ConfigOption { return func(cfg *Config) { cfg.Database = database @@ -161,14 +161,14 @@ func WithIdleCheckFrequency(idleCheckFrequency time.Duration) ConfigOption { } } -// WithHealthCheckTableName - Name of the Table that is created to try if database is writeable +// WithHealthCheckTableName - Name of the Table that is created to try if database is writeable. func WithHealthCheckTableName(healthCheckTableName string) ConfigOption { return func(cfg *Config) { cfg.HealthCheckTableName = healthCheckTableName } } -// WithHealthCheckResultTTL - Amount of time to cache the last health check result +// WithHealthCheckResultTTL - Amount of time to cache the last health check result. func WithHealthCheckResultTTL(healthCheckResultTTL time.Duration) ConfigOption { return func(cfg *Config) { cfg.HealthCheckResultTTL = healthCheckResultTTL diff --git a/backend/postgres/options_test.go b/backend/postgres/options_test.go index 0ea808de4..08b131c8b 100644 --- a/backend/postgres/options_test.go +++ b/backend/postgres/options_test.go @@ -10,7 +10,9 @@ import ( func TestWithApplicationName(t *testing.T) { param := "ApplicationName" + var conf Config + f := WithApplicationName(param) f(&conf) require.Equal(t, conf.ApplicationName, param) @@ -18,7 +20,9 @@ func TestWithApplicationName(t *testing.T) { func TestWithDatabase(t *testing.T) { param := "Database" + var conf Config + f := WithDatabase(param) f(&conf) require.Equal(t, conf.Database, param) @@ -26,7 +30,9 @@ func TestWithDatabase(t *testing.T) { func TestWithDialTimeout(t *testing.T) { param := 5 * time.Second + var conf Config + f := WithDialTimeout(param) f(&conf) require.Equal(t, conf.DialTimeout, param) @@ -34,7 +40,9 @@ func TestWithDialTimeout(t *testing.T) { func TestWithHealthCheckResultTTL(t *testing.T) { param := 5 * time.Second + var conf Config + f := WithHealthCheckResultTTL(param) f(&conf) require.Equal(t, conf.HealthCheckResultTTL, param) @@ -42,7 +50,9 @@ func TestWithHealthCheckResultTTL(t *testing.T) { func TestWithHealthCheckTableName(t *testing.T) { param := "HealthCheckTableName" + var conf Config + f := WithHealthCheckTableName(param) f(&conf) require.Equal(t, conf.HealthCheckTableName, param) @@ -50,7 +60,9 @@ func TestWithHealthCheckTableName(t *testing.T) { func TestWithHost(t *testing.T) { param := "Host" + var conf Config + f := WithHost(param) f(&conf) require.Equal(t, conf.Host, param) @@ -58,7 +70,9 @@ func TestWithHost(t *testing.T) { func TestWithIdleCheckFrequency(t *testing.T) { param := 5 * time.Second + var conf Config + f := WithIdleCheckFrequency(param) f(&conf) require.Equal(t, conf.IdleCheckFrequency, param) @@ -66,7 +80,9 @@ func TestWithIdleCheckFrequency(t *testing.T) { func TestWithIdleTimeout(t *testing.T) { param := 5 * time.Second + var conf Config + f := WithIdleTimeout(param) f(&conf) require.Equal(t, conf.IdleTimeout, param) @@ -74,7 +90,9 @@ func TestWithIdleTimeout(t *testing.T) { func TestWithMaxConnAge(t *testing.T) { param := 5 * time.Second + var conf Config + f := WithMaxConnAge(param) f(&conf) require.Equal(t, conf.MaxConnAge, param) @@ -82,7 +100,9 @@ func TestWithMaxConnAge(t *testing.T) { func TestWithMaxRetries(t *testing.T) { param := 42 + var conf Config + f := WithMaxRetries(param) f(&conf) require.Equal(t, conf.MaxRetries, param) @@ -90,7 +110,9 @@ func TestWithMaxRetries(t *testing.T) { func TestWithMaxRetryBackoff(t *testing.T) { param := 5 * time.Second + var conf Config + f := WithMaxRetryBackoff(param) f(&conf) require.Equal(t, conf.MaxRetryBackoff, param) @@ -98,7 +120,9 @@ func TestWithMaxRetryBackoff(t *testing.T) { func TestWithMinIdleConns(t *testing.T) { param := 42 + var conf Config + f := WithMinIdleConns(param) f(&conf) require.Equal(t, conf.MinIdleConns, param) @@ -106,7 +130,9 @@ func TestWithMinIdleConns(t *testing.T) { func TestWithMinRetryBackoff(t *testing.T) { param := 5 * time.Second + var conf Config + f := WithMinRetryBackoff(param) f(&conf) require.Equal(t, conf.MinRetryBackoff, param) @@ -114,7 +140,9 @@ func TestWithMinRetryBackoff(t *testing.T) { func TestWithPassword(t *testing.T) { param := "Password" + var conf Config + f := WithPassword(param) f(&conf) require.Equal(t, conf.Password, param) @@ -122,7 +150,9 @@ func TestWithPassword(t *testing.T) { func TestWithPoolSize(t *testing.T) { param := 42 + var conf Config + f := WithPoolSize(param) f(&conf) require.Equal(t, conf.PoolSize, param) @@ -130,7 +160,9 @@ func TestWithPoolSize(t *testing.T) { func TestWithPoolTimeout(t *testing.T) { param := 5 * time.Second + var conf Config + f := WithPoolTimeout(param) f(&conf) require.Equal(t, conf.PoolTimeout, param) @@ -138,7 +170,9 @@ func TestWithPoolTimeout(t *testing.T) { func TestWithPort(t *testing.T) { param := 42 + var conf Config + f := WithPort(param) f(&conf) require.Equal(t, conf.Port, param) @@ -146,7 +180,9 @@ func TestWithPort(t *testing.T) { func TestWithReadTimeout(t *testing.T) { param := 5 * time.Second + var conf Config + f := WithReadTimeout(param) f(&conf) require.Equal(t, conf.ReadTimeout, param) @@ -154,7 +190,9 @@ func TestWithReadTimeout(t *testing.T) { func TestWithRetryStatementTimeout(t *testing.T) { param := true + var conf Config + f := WithRetryStatementTimeout(param) f(&conf) require.Equal(t, conf.RetryStatementTimeout, param) @@ -162,7 +200,9 @@ func TestWithRetryStatementTimeout(t *testing.T) { func TestWithUser(t *testing.T) { param := "User" + var conf Config + f := WithUser(param) f(&conf) require.Equal(t, conf.User, param) @@ -170,7 +210,9 @@ func TestWithUser(t *testing.T) { func TestWithWriteTimeout(t *testing.T) { param := 5 * time.Second + var conf Config + f := WithWriteTimeout(param) f(&conf) require.Equal(t, conf.WriteTimeout, param) @@ -194,7 +236,9 @@ func TestWithLogReadWriteOnly(t *testing.T) { for _, tc := range cases { read := tc[0] write := tc[1] + var conf Config + f := WithQueryLogging(read, write) f(&conf) assert.Equal(t, conf.LogRead, read) diff --git a/backend/postgres/postgres.go b/backend/postgres/postgres.go index 9ed8ba53b..9d068e9ab 100644 --- a/backend/postgres/postgres.go +++ b/backend/postgres/postgres.go @@ -14,13 +14,11 @@ import ( "sync" "time" - "github.com/getsentry/sentry-go" - - "github.com/rs/zerolog" - "github.com/caarlos0/env/v10" + "github.com/getsentry/sentry-go" "github.com/go-pg/pg" "github.com/prometheus/client_golang/prometheus" + "github.com/rs/zerolog" "github.com/pace/bricks/maintenance/health/servicehealthcheck" "github.com/pace/bricks/maintenance/log" @@ -34,7 +32,7 @@ type Config struct { Database string `env:"POSTGRES_DB" envDefault:"postgres"` // ApplicationName is the application name. Used in logs on Pg side. - // Only availaible from pg-9.0. + // Only available from pg-9.0. ApplicationName string `env:"POSTGRES_APPLICATION_NAME" envDefault:"-"` // Maximum number of retries before giving up. MaxRetries int `env:"POSTGRES_MAX_RETRIES" envDefault:"5"` @@ -133,8 +131,7 @@ func init() { prometheus.MustRegister(metricQueryAffectedTotal) // parse log Config - err := env.Parse(&cfg) - if err != nil { + if err := env.Parse(&cfg); err != nil { log.Fatalf("Failed to parse postgres environment: %v", err) } @@ -169,18 +166,22 @@ var ( // logging and metrics. func DefaultConnectionPool() *pg.DB { var err error + defaultPoolOnce.Do(func() { if defaultPool == nil { defaultPool = ConnectionPool() // add metrics metrics := NewConnectionPoolMetrics() prometheus.MustRegister(metrics) + err = metrics.ObserveRegularly(context.Background(), defaultPool, "default") } }) + if err != nil { panic(err) } + return defaultPool } @@ -231,6 +232,7 @@ func CustomConnectionPool(opts *pg.Options) *pg.DB { Str("database", opts.Database). Str("as", opts.ApplicationName). Msg("PostgreSQL connection pool created") + db := pg.Connect(opts) if cfg.LogWrite || cfg.LogRead { db.OnQueryProcessed(queryLogger) @@ -259,6 +261,7 @@ func determineQueryMode(qry string) queryMode { if strings.HasPrefix(strings.ToLower(strings.TrimSpace(qry)), "select") { return readMode } + return writeMode } @@ -268,16 +271,17 @@ func queryLogger(event *pg.QueryProcessedEvent) { if !(cfg.LogRead || cfg.LogWrite) { return } - // we can only and should only perfom the following check if we have the information availaible + // we can only and should only perfom the following check if we have the information available mode := determineQueryMode(q) if mode == readMode && !cfg.LogRead { return } + if mode == writeMode && !cfg.LogWrite { return } - } + ctx := event.DB.Context() dur := float64(time.Since(event.StartTime)) / float64(time.Millisecond) @@ -310,6 +314,7 @@ func queryLogger(event *pg.QueryProcessedEvent) { // this is only a display issue not a "real" issue le.Msgf("%v", qe) } + le.Msg(q) } @@ -326,6 +331,7 @@ func getQueryType(s string) string { if len(p) > 0 { return strings.ToUpper(s[:p[0]]) } + return strings.ToUpper(s) } @@ -374,5 +380,6 @@ func metricsAdapter(event *pg.QueryProcessedEvent, opts *pg.Options) { metricQueryRowsTotal.With(labels).Add(float64(r.RowsReturned())) metricQueryAffectedTotal.With(labels).Add(math.Max(0, float64(r.RowsAffected()))) } + metricQueryDurationSeconds.With(labels).Observe(dur) } diff --git a/backend/postgres/postgres_test.go b/backend/postgres/postgres_test.go index 5206720e9..144e293a3 100644 --- a/backend/postgres/postgres_test.go +++ b/backend/postgres/postgres_test.go @@ -12,15 +12,17 @@ func TestIntegrationConnectionPool(t *testing.T) { if testing.Short() { t.SkipNow() } + db := ConnectionPool() + var result struct { Calc int } + _, err := db.QueryOne(&result, `SELECT ? + ? AS Calc`, 10, 10) //nolint:errcheck if err != nil { t.Errorf("got %v", err) } - // Note: This test can't actually test the logging correctly // but the code will be accessed } @@ -29,15 +31,17 @@ func TestIntegrationConnectionPoolNoLogging(t *testing.T) { if testing.Short() { t.SkipNow() } + db := ConnectionPool(WithQueryLogging(false, false)) + var result struct { Calc int } + _, err := db.QueryOne(&result, `SELECT ? + ? AS Calc`, 10, 10) //nolint:errcheck if err != nil { t.Errorf("got %v", err) } - // Note: This test can't actually test the logging correctly // but the code will be accessed } diff --git a/backend/queue/config.go b/backend/queue/config.go index aa908fa4a..00b355ce2 100644 --- a/backend/queue/config.go +++ b/backend/queue/config.go @@ -20,8 +20,7 @@ type config struct { var cfg config func init() { - err := env.Parse(&cfg) - if err != nil { + if err := env.Parse(&cfg); err != nil { log.Fatalf("Failed to parse queue environment: %v", err) } } diff --git a/backend/queue/metrics.go b/backend/queue/metrics.go index 03769e321..1fe8a6671 100644 --- a/backend/queue/metrics.go +++ b/backend/queue/metrics.go @@ -5,10 +5,11 @@ import ( "time" "github.com/adjust/rmq/v5" + "github.com/prometheus/client_golang/prometheus" + pberrors "github.com/pace/bricks/maintenance/errors" "github.com/pace/bricks/maintenance/log" "github.com/pace/bricks/pkg/routine" - "github.com/prometheus/client_golang/prometheus" ) type queueStatsGauges struct { @@ -31,11 +32,13 @@ func gatherMetrics(connection rmq.Connection) { log.Ctx(ctx).Debug().Err(err).Msg("rmq metrics: could not get open queues") pberrors.Handle(ctx, err) } + stats, err := connection.CollectStats(queues) if err != nil { log.Ctx(ctx).Debug().Err(err).Msg("rmq metrics: could not collect stats") pberrors.Handle(ctx, err) } + for queue, queueStats := range stats.QueueStats { labels := prometheus.Labels{ "queue": queue, @@ -50,7 +53,7 @@ func gatherMetrics(connection rmq.Connection) { }) } -func registerConnection(connection rmq.Connection) queueStatsGauges { +func registerConnection(_ rmq.Connection) queueStatsGauges { gauges := queueStatsGauges{ readyGauge: prometheus.NewGaugeVec(prometheus.GaugeOpts{ Namespace: "rmq", diff --git a/backend/queue/rmq.go b/backend/queue/rmq.go index 83e76111e..25f9c35fe 100644 --- a/backend/queue/rmq.go +++ b/backend/queue/rmq.go @@ -6,13 +6,13 @@ import ( "sync" "time" + "github.com/adjust/rmq/v5" + "github.com/pace/bricks/backend/redis" pberrors "github.com/pace/bricks/maintenance/errors" "github.com/pace/bricks/maintenance/health/servicehealthcheck" "github.com/pace/bricks/maintenance/log" "github.com/pace/bricks/pkg/routine" - - "github.com/adjust/rmq/v5" ) var ( @@ -32,29 +32,34 @@ type queueHealth struct { func (h *queueHealth) isMarkedHealthy() bool { h.mu.Lock() defer h.mu.Unlock() + return h.markedUnhealthyAt.IsZero() } func (h *queueHealth) markUnhealthy() { h.mu.Lock() defer h.mu.Unlock() + h.markedUnhealthyAt = time.Now() } func (h *queueHealth) markHealthy() { h.mu.Lock() defer h.mu.Unlock() + h.markedUnhealthyAt = time.Time{} } func (h *queueHealth) getMarkedUnhealthyAt() time.Time { h.mu.Lock() defer h.mu.Unlock() + return h.markedUnhealthyAt } func initDefault() error { var err error + initMutex.Lock() defer initMutex.Unlock() @@ -67,9 +72,8 @@ func initDefault() error { ctx := log.ContextWithSink(log.WithContext(context.Background()), new(log.Sink)) routine.Run(ctx, func(ctx context.Context) { for { - err := <-errChan - if err != nil { - pberrors.Handle(ctx, fmt.Errorf("rmq reported error in background task: %s", err)) + if err := <-errChan; err != nil { + pberrors.Handle(ctx, fmt.Errorf("rmq reported error in background task: %w", err)) } } }) @@ -79,8 +83,10 @@ func initDefault() error { rmqConnection = nil return err } + gatherMetrics(rmqConnection) servicehealthcheck.RegisterHealthCheck("rmq", &HealthCheck{}) + return nil } @@ -88,20 +94,23 @@ func initDefault() error { // Whenever the number of items in the queue exceeds the healthyLimit // The queue will be reported as unhealthy // If the queue has already been opened, it will just be returned. Limits will not -// be updated +// be updated. func NewQueue(name string, healthyLimit int) (rmq.Queue, error) { - err := initDefault() - if err != nil { + if err := initDefault(); err != nil { return nil, err } + queue, err := rmqConnection.OpenQueue(name) if err != nil { return nil, err } + if _, ok := queueHealthLimits.Load(name); ok { return queue, nil } + queueHealthLimits.Store(name, &queueHealth{limit: healthyLimit}) + return queue, nil } @@ -113,7 +122,7 @@ type HealthCheck struct { } // HealthCheck checks if the queues are healthy, i.e. whether the number of -// items accumulated is below the healthyLimit defined when opening the queue +// items accumulated is below the healthyLimit defined when opening the queue. func (h *HealthCheck) HealthCheck(ctx context.Context) servicehealthcheck.HealthCheckResult { if !h.IgnoreInterval && time.Since(h.state.LastChecked()) <= cfg.HealthCheckResultTTL { return h.state.GetState() @@ -122,23 +131,34 @@ func (h *HealthCheck) HealthCheck(ctx context.Context) servicehealthcheck.Health queues, err := rmqConnection.GetOpenQueues() if err != nil { log.Ctx(ctx).Debug().Err(err).Msg("rmq HealthCheck: could not get open queues") - h.state.SetErrorState(fmt.Errorf("error while retrieving open queues: %s", err)) + h.state.SetErrorState(fmt.Errorf("error while retrieving open queues: %w", err)) + return h.state.GetState() } + stats, err := rmqConnection.CollectStats(queues) if err != nil { log.Ctx(ctx).Debug().Err(err).Msg("rmq HealthCheck: could not collect stats") - h.state.SetErrorState(fmt.Errorf("error while collecting stats: %s", err)) + h.state.SetErrorState(fmt.Errorf("error while collecting stats: %w", err)) + return h.state.GetState() } + queueHealthLimits.Range(func(k, v interface{}) bool { - name := k.(string) - hl := v.(*queueHealth) + name, _ := k.(string) + + hl, ok := v.(*queueHealth) + if !ok { + return false + } + stat := stats.QueueStats[name] + if stat.ReadyCount > int64(hl.limit) { if hl.isMarkedHealthy() { hl.markUnhealthy() h.state.SetHealthy() + return true } // queue health is still pending @@ -146,12 +166,16 @@ func (h *HealthCheck) HealthCheck(ctx context.Context) servicehealthcheck.Health return true } - h.state.SetErrorState(fmt.Errorf("Queue '%s' exceeded safe health limit of '%d'", name, hl.limit)) + h.state.SetErrorState(fmt.Errorf("queue '%s' exceeded safe health limit of '%d'", name, hl.limit)) + return false } + h.state.SetHealthy() hl.markHealthy() + return true }) + return h.state.GetState() } diff --git a/backend/queue/rmq_test.go b/backend/queue/rmq_test.go index 25f263d54..5636c8830 100644 --- a/backend/queue/rmq_test.go +++ b/backend/queue/rmq_test.go @@ -5,23 +5,28 @@ import ( "testing" "time" - "github.com/pace/bricks/maintenance/log" "github.com/stretchr/testify/assert" + + "github.com/pace/bricks/maintenance/log" ) func TestIntegrationHealthCheck(t *testing.T) { if testing.Short() { t.SkipNow() } + ctx := log.WithContext(context.Background()) cfg.HealthCheckPendingStateInterval = time.Second * 2 q1, err := NewQueue("integrationTestTasks", 1) assert.NoError(t, err) + err = q1.Publish("nothing here") assert.NoError(t, err) time.Sleep(time.Second) + check := &HealthCheck{IgnoreInterval: true} + res := check.HealthCheck(ctx) if res.State != "OK" { t.Errorf("Expected health check to be OK for a non-full queue: state %s, message: %s", res.State, res.Msg) @@ -37,12 +42,14 @@ func TestIntegrationHealthCheck(t *testing.T) { } // queue health pending time.Sleep(time.Second) + res = check.HealthCheck(ctx) if res.State != "OK" { t.Errorf("Expected health check to be OK") } // queue health no longer pending time.Sleep(time.Second * 2) + res = check.HealthCheck(ctx) if res.State == "OK" { t.Errorf("Expected health check to be ERR for a full queue") @@ -57,6 +64,7 @@ func TestIntegrationHealthCheck(t *testing.T) { err = q1.Publish("nothing here") assert.NoError(t, err) + err = q1.Publish("nothing here either") assert.NoError(t, err) // queue health pending again diff --git a/backend/redis/errors.go b/backend/redis/errors.go index aeac7d08e..d15653d9e 100644 --- a/backend/redis/errors.go +++ b/backend/redis/errors.go @@ -15,6 +15,7 @@ func IsErrConnectionFailed(err error) bool { } // go-redis has this check internally for network errors - _, ok := err.(net.Error) + _, ok := err.(net.Error) //nolint:errorlint + return ok } diff --git a/backend/redis/health_redis.go b/backend/redis/health_redis.go index fc286b42a..8a222cdac 100644 --- a/backend/redis/health_redis.go +++ b/backend/redis/health_redis.go @@ -6,8 +6,9 @@ import ( "context" "time" - "github.com/pace/bricks/maintenance/health/servicehealthcheck" "github.com/redis/go-redis/v9" + + "github.com/pace/bricks/maintenance/health/servicehealthcheck" ) // HealthCheck checks the state of a redis connection. It must not be changed @@ -19,7 +20,7 @@ type HealthCheck struct { // HealthCheck checks if the redis is healthy. If the last result is outdated, // redis is checked for writeability and readability, -// otherwise return the old result +// otherwise return the old result. func (h *HealthCheck) HealthCheck(ctx context.Context) servicehealthcheck.HealthCheckResult { if time.Since(h.state.LastChecked()) <= cfg.HealthCheckResultTTL { // the last health check is not outdated, an can be reused. @@ -32,12 +33,12 @@ func (h *HealthCheck) HealthCheck(ctx context.Context) servicehealthcheck.Health return h.state.GetState() } // If writing worked try reading - err := h.Client.Get(ctx, cfg.HealthCheckKey).Err() - if err != nil { + if err := h.Client.Get(ctx, cfg.HealthCheckKey).Err(); err != nil { h.state.SetErrorState(err) return h.state.GetState() } // If reading an writing worked set the Health Check to healthy h.state.SetHealthy() + return h.state.GetState() } diff --git a/backend/redis/health_redis_test.go b/backend/redis/health_redis_test.go index 2dd753da0..b0ec24da9 100644 --- a/backend/redis/health_redis_test.go +++ b/backend/redis/health_redis_test.go @@ -17,21 +17,27 @@ import ( func setup() *http.Response { r := http2.Router() rec := httptest.NewRecorder() - req := httptest.NewRequest("GET", "/health/check", nil) + req := httptest.NewRequest(http.MethodGet, "/health/check", nil) r.ServeHTTP(rec, req) - resp := rec.Result() - defer resp.Body.Close() - return resp + + return rec.Result() } -// TestIntegrationHealthCheck tests if redis health check ist working like expected +// TestIntegrationHealthCheck tests if redis health check ist working like expected. func TestIntegrationHealthCheck(t *testing.T) { if testing.Short() { t.SkipNow() } + time.Sleep(time.Second) + resp := setup() - if resp.StatusCode != 200 { + + defer func() { + _ = resp.Body.Close() + }() + + if resp.StatusCode != http.StatusOK { t.Errorf("Expected /health/check to respond with 200, got: %d", resp.StatusCode) } @@ -39,6 +45,7 @@ func TestIntegrationHealthCheck(t *testing.T) { if err != nil { log.Fatal(err) } + if !strings.Contains(string(data), "redis OK") { t.Errorf("Expected /health/check to return OK, got: %q", string(data[:])) } diff --git a/backend/redis/redis.go b/backend/redis/redis.go index c6ac49430..e640c1d3e 100755 --- a/backend/redis/redis.go +++ b/backend/redis/redis.go @@ -70,8 +70,7 @@ func init() { prometheus.MustRegister(paceRedisCmdDurationSeconds) // parse log config - err := env.Parse(&cfg) - if err != nil { + if err := env.Parse(&cfg); err != nil { log.Fatalf("Failed to parse redis environment: %v", err) } @@ -80,7 +79,7 @@ func init() { }) } -// Client with environment based configuration +// Client with environment based configuration. func Client(overwriteOpts ...func(*redis.Options)) *redis.Client { opts := &redis.Options{ Addr: cfg.Addrs[0], @@ -106,14 +105,14 @@ func Client(overwriteOpts ...func(*redis.Options)) *redis.Client { return CustomClient(opts) } -// CustomClient with passed configuration +// CustomClient with passed configuration. func CustomClient(opts *redis.Options) *redis.Client { log.Logger().Info().Str("addr", opts.Addr). Msg("Redis connection pool created") return redis.NewClient(opts) } -// ClusterClient with environment based configuration +// ClusterClient with environment based configuration. func ClusterClient() *redis.ClusterClient { return CustomClusterClient(&redis.ClusterOptions{ Addrs: cfg.Addrs, @@ -132,20 +131,20 @@ func ClusterClient() *redis.ClusterClient { }) } -// CustomClusterClient with passed configuration +// CustomClusterClient with passed configuration. func CustomClusterClient(opts *redis.ClusterOptions) *redis.ClusterClient { log.Logger().Info().Strs("addrs", opts.Addrs). Msg("Redis cluster connection pool created") return redis.NewClusterClient(opts) } -// WithContext adds a logging and tracing wrapper to the passed client +// WithContext adds a logging and tracing wrapper to the passed client. func WithContext(ctx context.Context, c *redis.Client) *redis.Client { c.AddHook(&logtracer{}) return c } -// WithClusterContext adds a logging and tracing wrapper to the passed client +// WithClusterContext adds a logging and tracing wrapper to the passed client. func WithClusterContext(ctx context.Context, c *redis.ClusterClient) *redis.ClusterClient { c.AddHook(&logtracer{}) return c @@ -160,11 +159,11 @@ type logtracerValues struct { span *sentry.Span } -func (lt *logtracer) DialHook(next redis.DialHook) redis.DialHook { +func (l *logtracer) DialHook(next redis.DialHook) redis.DialHook { return next } -func (lt *logtracer) ProcessHook(next redis.ProcessHook) redis.ProcessHook { +func (l *logtracer) ProcessHook(next redis.ProcessHook) redis.ProcessHook { return func(ctx context.Context, cmd redis.Cmder) error { startedAt := time.Now() @@ -186,8 +185,12 @@ func (lt *logtracer) ProcessHook(next redis.ProcessHook) redis.ProcessHook { _ = next(ctx, cmd) - vals := ctx.Value(logtracerKey{}).(*logtracerValues) - le := log.Ctx(ctx).Debug().Str("cmd", cmd.Name()).Str("sentry:category", "redis") + vals, ok := ctx.Value(logtracerKey{}).(*logtracerValues) + if !ok { + vals = &logtracerValues{} + } + + le := log.Ctx(ctx).Debug().Str("cmd", cmd.Name()).Str("sentry:category", "redis") //nolint:zerologlint // add error cmdErr := cmd.Err() diff --git a/cmd/pb/main.go b/cmd/pb/main.go index 69652da49..ac5011e68 100644 --- a/cmd/pb/main.go +++ b/cmd/pb/main.go @@ -18,8 +18,8 @@ func main() { Args: cobra.MaximumNArgs(1), } addRootCommands(rootCmd) - err := rootCmd.Execute() - if err != nil { + + if err := rootCmd.Execute(); err != nil { log.Fatal(err) } } @@ -27,6 +27,7 @@ func main() { // pace ... func addRootCommands(rootCmd *cobra.Command) { var restSource string + rootCmdNew := &cobra.Command{ Use: "new NAME", Args: cobra.ExactArgs(1), @@ -67,6 +68,7 @@ func addRootCommands(rootCmd *cobra.Command) { rootCmd.AddCommand(rootCmdEdit) var runCmd string + rootCmdRun := &cobra.Command{ Use: "run NAME", Args: cobra.ExactArgs(1), @@ -81,6 +83,7 @@ func addRootCommands(rootCmd *cobra.Command) { rootCmd.AddCommand(rootCmdRun) var testGoConvey bool + rootCmdTest := &cobra.Command{ Use: "test NAME", Args: cobra.ExactArgs(1), @@ -136,6 +139,7 @@ func (e *errorDefinitionsOutputFlag) Type() string { // pace service generate ... func addServiceGenerateCommands(rootCmdGenerate *cobra.Command) { var pkgName, path, source string + cmdRest := &cobra.Command{ Use: "rest", Args: cobra.NoArgs, @@ -153,6 +157,7 @@ func addServiceGenerateCommands(rootCmdGenerate *cobra.Command) { rootCmdGenerate.AddCommand(cmdRest) var commandsPath string + cmdCommands := &cobra.Command{ Use: "commands NAME", Args: cobra.ExactArgs(1), @@ -165,6 +170,7 @@ func addServiceGenerateCommands(rootCmdGenerate *cobra.Command) { rootCmdGenerate.AddCommand(cmdCommands) var dockerfilePath string + cmdDockerfile := &cobra.Command{ Use: "dockerfile NAME", Args: cobra.ExactArgs(1), @@ -179,6 +185,7 @@ func addServiceGenerateCommands(rootCmdGenerate *cobra.Command) { rootCmdGenerate.AddCommand(cmdDockerfile) var makefilePath string + cmdMakefile := &cobra.Command{ Use: "makefile NAME", Args: cobra.ExactArgs(1), @@ -192,6 +199,7 @@ func addServiceGenerateCommands(rootCmdGenerate *cobra.Command) { rootCmdGenerate.AddCommand(cmdMakefile) var errorsDefinitionsPkgName, errorsDefinitionsPath, errorsDefinitionsSource string + errorDefinitionsOutput := goOutputFlag cmdErrorDefinitions := &cobra.Command{ Use: "error-definitions", diff --git a/grpc/client.go b/grpc/client.go index 81998f507..169ee2941 100644 --- a/grpc/client.go +++ b/grpc/client.go @@ -6,6 +6,9 @@ import ( "context" "time" + grpc_prometheus "github.com/grpc-ecosystem/go-grpc-middleware/providers/prometheus" + grpc_retry "github.com/grpc-ecosystem/go-grpc-middleware/retry" + grpc_opentracing "github.com/grpc-ecosystem/go-grpc-middleware/tracing/opentracing" "google.golang.org/grpc" "google.golang.org/grpc/credentials/insecure" "google.golang.org/grpc/metadata" @@ -14,10 +17,6 @@ import ( "github.com/pace/bricks/http/security" "github.com/pace/bricks/locale" "github.com/pace/bricks/maintenance/log" - - grpc_prometheus "github.com/grpc-ecosystem/go-grpc-middleware/providers/prometheus" - grpc_retry "github.com/grpc-ecosystem/go-grpc-middleware/retry" - grpc_opentracing "github.com/grpc-ecosystem/go-grpc-middleware/tracing/opentracing" ) // Deprecated: Use NewClient instead. @@ -31,15 +30,13 @@ func Dial(addr string) (*grpc.ClientConn, error) { } func NewClient(addr string) (*grpc.ClientConn, error) { - var conn *grpc.ClientConn - clientMetrics := grpc_prometheus.NewClientMetrics() opts := []grpc_retry.CallOption{ grpc_retry.WithBackoff(grpc_retry.BackoffLinear(100 * time.Millisecond)), } - conn, err := grpc.NewClient(addr, + return grpc.NewClient(addr, grpc.WithTransportCredentials(insecure.NewCredentials()), grpc.WithChainStreamInterceptor( grpc_opentracing.StreamClientInterceptor(), @@ -53,6 +50,7 @@ func NewClient(addr string) (*grpc.ClientConn, error) { Str("type", "stream"). Err(err). Msg("GRPC requested") + return cs, err }, ), @@ -68,23 +66,26 @@ func NewClient(addr string) (*grpc.ClientConn, error) { Str("type", "unary"). Err(err). Msg("GRPC requested") + return err }, ), ) - return conn, err } func prepareClientContext(ctx context.Context) context.Context { if loc, ok := locale.FromCtx(ctx); ok { ctx = metadata.AppendToOutgoingContext(ctx, MetadataKeyLocale, loc.Serialize()) } + if token, ok := security.GetTokenFromContext(ctx); ok { ctx = metadata.AppendToOutgoingContext(ctx, MetadataKeyBearerToken, token.GetValue()) } + if reqID := log.RequestIDFromContext(ctx); reqID != "" { ctx = metadata.AppendToOutgoingContext(ctx, MetadataKeyRequestID, reqID) } + ctx = EncodeContextWithUTMData(ctx) if dep := middleware.ExternalDependencyContextFromContext(ctx); dep != nil { diff --git a/grpc/middleware.go b/grpc/middleware.go index 04cea44e2..1a7ac841f 100644 --- a/grpc/middleware.go +++ b/grpc/middleware.go @@ -7,8 +7,9 @@ import ( "encoding/gob" "strings" - "github.com/pace/bricks/pkg/tracking/utm" "google.golang.org/grpc/metadata" + + "github.com/pace/bricks/pkg/tracking/utm" ) const utmMetadataKey = "utm-bin" // IMPORTANT -bin post-fix allows us to send binary data via grpc metadata, otherwise it will break the protocol @@ -18,10 +19,12 @@ func ContextWithUTMFromMetadata(parentCtx context.Context, md metadata.MD) conte if len(dataSlice) == 0 { return parentCtx } + var utmData utm.UTMData if err := gob.NewDecoder(strings.NewReader(dataSlice[0])).Decode(&utmData); err != nil { return parentCtx } + return utm.ContextWithUTMData(parentCtx, utmData) } @@ -30,9 +33,11 @@ func EncodeContextWithUTMData(parentCtx context.Context) context.Context { if !exists { return parentCtx } + w := strings.Builder{} if err := gob.NewEncoder(&w).Encode(utmData); err != nil { return parentCtx } + return metadata.AppendToOutgoingContext(parentCtx, utmMetadataKey, w.String()) } diff --git a/grpc/middleware_test.go b/grpc/middleware_test.go index 474070bd3..4001bda71 100644 --- a/grpc/middleware_test.go +++ b/grpc/middleware_test.go @@ -6,9 +6,10 @@ import ( "context" "testing" - "github.com/pace/bricks/pkg/tracking/utm" "github.com/stretchr/testify/require" "google.golang.org/grpc/metadata" + + "github.com/pace/bricks/pkg/tracking/utm" ) func TestEncodeContextWithUTMData(t *testing.T) { @@ -25,6 +26,7 @@ func TestEncodeContextWithUTMData(t *testing.T) { ctx = EncodeContextWithUTMData(ctx) md, exists := metadata.FromOutgoingContext(ctx) require.True(t, exists) + ctx2 := context.Background() ctx2 = ContextWithUTMFromMetadata(ctx2, md) utmData, exists := utm.FromContext(ctx2) diff --git a/grpc/server.go b/grpc/server.go index 99d9296d6..60073bf3d 100644 --- a/grpc/server.go +++ b/grpc/server.go @@ -9,28 +9,28 @@ import ( "strings" "time" + "github.com/caarlos0/env/v10" grpc_middleware "github.com/grpc-ecosystem/go-grpc-middleware" grpc_auth "github.com/grpc-ecosystem/go-grpc-middleware/auth" grpc_prometheus "github.com/grpc-ecosystem/go-grpc-middleware/providers/prometheus" grpc_ctxtags "github.com/grpc-ecosystem/go-grpc-middleware/tags" grpc_opentracing "github.com/grpc-ecosystem/go-grpc-middleware/tracing/opentracing" - "github.com/pace/bricks/http/middleware" - "github.com/pace/bricks/http/security" - "github.com/pace/bricks/locale" - "github.com/pace/bricks/maintenance/errors" - "github.com/pace/bricks/maintenance/log" - "github.com/pace/bricks/maintenance/log/hlog" "github.com/rs/xid" "github.com/rs/zerolog" zlog "github.com/rs/zerolog/log" - - "github.com/caarlos0/env/v10" "google.golang.org/grpc" "google.golang.org/grpc/metadata" "google.golang.org/grpc/peer" + + "github.com/pace/bricks/http/middleware" + "github.com/pace/bricks/http/security" + "github.com/pace/bricks/locale" + "github.com/pace/bricks/maintenance/errors" + "github.com/pace/bricks/maintenance/log" + "github.com/pace/bricks/maintenance/log/hlog" ) -var InternalServerError = errors.New("internal server error") +var ErrInternalServer = errors.New("internal server error") type Config struct { Address string `env:"GRPC_ADDR" envDefault:":3001"` @@ -46,18 +46,20 @@ func ListenAndServe(gs *grpc.Server) error { if err != nil { return err } + log.Logger().Info().Str("addr", listener.Addr().String()).Msg("Starting grpc server ...") - err = gs.Serve(listener) - if err != nil { + + if err := gs.Serve(listener); err != nil { return err } + return nil } func Listener() (net.Listener, error) { var cfg Config - err := env.Parse(&cfg) - if err != nil { + + if err := env.Parse(&cfg); err != nil { return nil, fmt.Errorf("failed to parse grpc server environment: %w", err) } @@ -65,6 +67,7 @@ func Listener() (net.Listener, error) { if err != nil { return nil, fmt.Errorf("unable to create grpc listener for %q: %w", cfg.Address, err) } + return tcpListener, nil } @@ -82,6 +85,7 @@ func Server(ab AuthBackend) *grpc.Server { wrappedStream := grpc_middleware.WrapServerStream(stream) wrappedStream.WrappedContext = ctx + var addr string if p, ok := peer.FromContext(ctx); ok { addr = p.Addr.String() @@ -98,12 +102,15 @@ func Server(ab AuthBackend) *grpc.Server { Str("user_agent", strings.Join(md.Get("user-agent"), ",")). Err(err). Msg("GRPC completed Stream") + return err }, func(srv interface{}, stream grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) (err error) { defer errors.HandleWithCtx(stream.Context(), "GRPC "+info.FullMethod) - err = InternalServerError // default in case of a panic + + err = ErrInternalServer // default in case of a panic err = handler(srv, stream) + return err }, grpc_auth.StreamServerInterceptor(ab.AuthorizeStream), @@ -131,12 +138,15 @@ func Server(ab AuthBackend) *grpc.Server { Str("user_agent", strings.Join(md.Get("user-agent"), ",")). Err(err). Msg("GRPC completed Unary") + return }, func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (resp interface{}, err error) { defer errors.HandleWithCtx(ctx, "GRPC "+info.FullMethod) - err = InternalServerError // default in case of a panic + + err = ErrInternalServer // default in case of a panic resp, err = handler(ctx, req) + return }, grpc_auth.UnaryServerInterceptor(ab.AuthorizeUnary), @@ -152,11 +162,14 @@ func prepareContext(ctx context.Context) (context.Context, metadata.MD) { // add request context if req_id is given var reqID xid.ID + if ri := md.Get(MetadataKeyRequestID); len(ri) > 0 { var err error + reqID, err = xid.FromString(ri[0]) if err != nil { log.Debugf("unable to parse xid from req_id: %v", err) + reqID = xid.New() } } else { diff --git a/grpc/server_test.go b/grpc/server_test.go index a2bda5f50..ebb3d69e7 100644 --- a/grpc/server_test.go +++ b/grpc/server_test.go @@ -7,12 +7,13 @@ import ( "context" "testing" - "github.com/pace/bricks/http/middleware" - "github.com/pace/bricks/locale" - "github.com/pace/bricks/maintenance/log" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "google.golang.org/grpc/metadata" + + "github.com/pace/bricks/http/middleware" + "github.com/pace/bricks/locale" + "github.com/pace/bricks/maintenance/log" ) func TestPrepareContext(t *testing.T) { @@ -23,6 +24,7 @@ func TestPrepareContext(t *testing.T) { assert.NotEmpty(t, log.RequestIDFromContext(ctx0)) var buf0 bytes.Buffer + l := log.Ctx(ctx0).Output(&buf0) l.Debug().Msg("test") assert.Contains(t, buf0.String(), "{\"level\":\"debug\",\"req_id\":\""+ @@ -40,12 +42,14 @@ func TestPrepareContext(t *testing.T) { assert.Len(t, md.Get(MetadataKeyRequestID), 0) assert.Len(t, md.Get(MetadataKeyBearerToken), 0) assert.Equal(t, "c690uu0ta2rv348epm8g", log.RequestIDFromContext(ctx1)) + loc, ok := locale.FromCtx(ctx1) assert.True(t, ok) assert.Equal(t, "fr-CH, fr;q=0.9, en;q=0.8, de;q=0.7, *;q=0.5", loc.Language()) assert.Equal(t, "Europe/Paris", loc.Timezone()) var buf1 bytes.Buffer + l = log.Ctx(ctx1).Output(&buf1) l.Debug().Msg("test") assert.Contains(t, buf1.String(), "{\"level\":\"debug\",\"req_id\":\"c690uu0ta2rv348epm8g\",\"time\":\"") @@ -63,10 +67,12 @@ func TestPrepareContext(t *testing.T) { assert.Equal(t, "c690uu0ta2rv348epm8g", log.RequestIDFromContext(ctx1)) var buf2 bytes.Buffer + l = log.Ctx(ctx2).Output(&buf2) l.Debug().Msg("test") assert.Contains(t, buf2.String(), "{\"level\":\"debug\",\"req_id\":\"c690uu0ta2rv348epm8g\",\"time\":\"") assert.Contains(t, buf2.String(), ",\"message\":\"test\"}\n") + _, ok = locale.FromCtx(ctx2) assert.False(t, ok) diff --git a/http/jsonapi/constants.go b/http/jsonapi/constants.go index 72f160b6f..5e3129b9d 100644 --- a/http/jsonapi/constants.go +++ b/http/jsonapi/constants.go @@ -4,7 +4,7 @@ package jsonapi const ( - // StructTag annotation strings + // StructTag annotation strings. annotationJSONAPI = "jsonapi" annotationPrimary = "primary" annotationClientID = "client-id" @@ -26,33 +26,33 @@ const ( // http://jsonapi.org/format/#fetching-pagination // KeyFirstPage is the key to the links object whose value contains a link to - // the first page of data + // the first page of data. KeyFirstPage = "first" // KeyLastPage is the key to the links object whose value contains a link to - // the last page of data + // the last page of data. KeyLastPage = "last" // KeyPreviousPage is the key to the links object whose value contains a link - // to the previous page of data + // to the previous page of data. KeyPreviousPage = "prev" // KeyNextPage is the key to the links object whose value contains a link to - // the next page of data + // the next page of data. KeyNextPage = "next" // QueryParamPageNumber is a JSON API query parameter used in a page based - // pagination strategy in conjunction with QueryParamPageSize + // pagination strategy in conjunction with QueryParamPageSize. QueryParamPageNumber = "page[number]" // QueryParamPageSize is a JSON API query parameter used in a page based - // pagination strategy in conjunction with QueryParamPageNumber + // pagination strategy in conjunction with QueryParamPageNumber. QueryParamPageSize = "page[size]" // QueryParamPageOffset is a JSON API query parameter used in an offset based - // pagination strategy in conjunction with QueryParamPageLimit + // pagination strategy in conjunction with QueryParamPageLimit. QueryParamPageOffset = "page[offset]" // QueryParamPageLimit is a JSON API query parameter used in an offset based - // pagination strategy in conjunction with QueryParamPageOffset + // pagination strategy in conjunction with QueryParamPageOffset. QueryParamPageLimit = "page[limit]" // QueryParamPageCursor is a JSON API query parameter used with a cursor-based - // strategy + // strategy. QueryParamPageCursor = "page[cursor]" ) diff --git a/http/jsonapi/errors.go b/http/jsonapi/errors.go index 2520577c7..69b3ba896 100644 --- a/http/jsonapi/errors.go +++ b/http/jsonapi/errors.go @@ -14,22 +14,22 @@ import ( // For more information on JSON API error payloads, see the spec here: // http://jsonapi.org/format/#document-top-level // and here: http://jsonapi.org/format/#error-objects. -func MarshalErrors(w io.Writer, errorObjects []*ErrorObject) error { +func MarshalErrors(w io.Writer, errorObjects []*ObjectError) error { return json.NewEncoder(w).Encode(&ErrorsPayload{Errors: errorObjects}) } // ErrorsPayload is a serializer struct for representing a valid JSON API errors payload. type ErrorsPayload struct { - Errors []*ErrorObject `json:"errors"` + Errors []*ObjectError `json:"errors"` } -// ErrorObject is an `Error` implementation as well as an implementation of the JSON API error object. +// ObjectError is an `Error` implementation as well as an implementation of the JSON API error object. // // The main idea behind this struct is that you can use it directly in your code as an error type // and pass it directly to `MarshalErrors` to get a valid JSON API errors payload. // For more information on Golang errors, see: https://golang.org/pkg/errors/ // For more information on the JSON API spec's error objects, see: http://jsonapi.org/format/#error-objects -type ErrorObject struct { +type ObjectError struct { // ID is a unique identifier for this particular occurrence of a problem. ID string `json:"id,omitempty"` @@ -50,6 +50,6 @@ type ErrorObject struct { } // Error implements the `Error` interface. -func (e *ErrorObject) Error() string { +func (e *ObjectError) Error() string { return fmt.Sprintf("Error: %s %s\n", e.Title, e.Detail) } diff --git a/http/jsonapi/errors_test.go b/http/jsonapi/errors_test.go index bef0155e2..ef29d0fbc 100644 --- a/http/jsonapi/errors_test.go +++ b/http/jsonapi/errors_test.go @@ -13,7 +13,8 @@ import ( ) func TestErrorObjectWritesExpectedErrorMessage(t *testing.T) { - err := &ErrorObject{Title: "Title test.", Detail: "Detail test."} + err := &ObjectError{Title: "Title test.", Detail: "Detail test."} + var input error = err output := input.Error() @@ -26,19 +27,19 @@ func TestErrorObjectWritesExpectedErrorMessage(t *testing.T) { func TestMarshalErrorsWritesTheExpectedPayload(t *testing.T) { marshalErrorsTableTasts := []struct { Title string - In []*ErrorObject + In []*ObjectError Out map[string]interface{} }{ { Title: "TestFieldsAreSerializedAsNeeded", - In: []*ErrorObject{{ID: "0", Title: "Test title.", Detail: "Test detail", Status: "400", Code: "E1100"}}, + In: []*ObjectError{{ID: "0", Title: "Test title.", Detail: "Test detail", Status: "http.StatusBadRequest", Code: "E1100"}}, Out: map[string]interface{}{"errors": []interface{}{ - map[string]interface{}{"id": "0", "title": "Test title.", "detail": "Test detail", "status": "400", "code": "E1100"}, + map[string]interface{}{"id": "0", "title": "Test title.", "detail": "Test detail", "status": "http.StatusBadRequest", "code": "E1100"}, }}, }, { Title: "TestMetaFieldIsSerializedProperly", - In: []*ErrorObject{{Title: "Test title.", Detail: "Test detail", Meta: &map[string]interface{}{"key": "val"}}}, + In: []*ObjectError{{Title: "Test title.", Detail: "Test detail", Meta: &map[string]interface{}{"key": "val"}}}, Out: map[string]interface{}{"errors": []interface{}{ map[string]interface{}{"title": "Test title.", "detail": "Test detail", "meta": map[string]interface{}{"key": "val"}}, }}, @@ -47,11 +48,12 @@ func TestMarshalErrorsWritesTheExpectedPayload(t *testing.T) { for _, testRow := range marshalErrorsTableTasts { t.Run(testRow.Title, func(t *testing.T) { buffer, output := bytes.NewBuffer(nil), map[string]interface{}{} + var writer io.Writer = buffer _ = MarshalErrors(writer, testRow.In) - err := json.Unmarshal(buffer.Bytes(), &output) - if err != nil { + + if err := json.Unmarshal(buffer.Bytes(), &output); err != nil { t.Fatal(err) } diff --git a/http/jsonapi/generator/generate.go b/http/jsonapi/generator/generate.go index 77ccd291f..d658cc983 100644 --- a/http/jsonapi/generator/generate.go +++ b/http/jsonapi/generator/generate.go @@ -5,6 +5,7 @@ package generator import ( "fmt" "io" + "log" "net/http" "net/url" "os" @@ -29,14 +30,19 @@ type Generator struct { generatedArrayTypes map[string]bool } -func loadSwaggerFromURI(loader *openapi3.SwaggerLoader, url *url.URL) (*openapi3.Swagger, error) { // nolint: interfacer +func loadSwaggerFromURI(loader *openapi3.SwaggerLoader, url *url.URL) (*openapi3.Swagger, error) { var schema *openapi3.Swagger resp, err := http.Get(url.String()) if err != nil { return nil, err } - defer resp.Body.Close() // nolint: errcheck + + defer func() { + if err := resp.Body.Close(); err != nil { + log.Println(err) + } + }() body, err := io.ReadAll(resp.Body) if err != nil { @@ -52,9 +58,10 @@ func loadSwaggerFromURI(loader *openapi3.SwaggerLoader, url *url.URL) (*openapi3 } // BuildSource generates the go code in the specified path with specified package name -// based on the passed schema source (url or file path) +// based on the passed schema source (url or file path). func (g *Generator) BuildSource(source, packagePath, packageName string) (string, error) { loader := openapi3.NewSwaggerLoader() + var schema *openapi3.Swagger if strings.HasPrefix(source, "http://") || strings.HasPrefix(source, "https://") { @@ -69,7 +76,7 @@ func (g *Generator) BuildSource(source, packagePath, packageName string) (string } } else { // read spec - data, err := os.ReadFile(source) // nolint: gosec + data, err := os.ReadFile(source) //nolint:gosec if err != nil { return "", err } @@ -85,7 +92,7 @@ func (g *Generator) BuildSource(source, packagePath, packageName string) (string } // BuildSchema generates the go code in the specified path with specified package name -// based on the passed schema +// based on the passed schema. func (g *Generator) BuildSchema(schema *openapi3.Swagger, packagePath, packageName string) (string, error) { g.generatedTypes = make(map[string]bool) g.generatedArrayTypes = make(map[string]bool) @@ -105,8 +112,7 @@ func (g *Generator) BuildSchema(schema *openapi3.Swagger, packagePath, packageNa } for _, bf := range buildFuncs { - err := bf(schema) - if err != nil { + if err := bf(schema); err != nil { return "", err } } diff --git a/http/jsonapi/generator/generate_handler.go b/http/jsonapi/generator/generate_handler.go index 77c745e40..ff70b4cbf 100644 --- a/http/jsonapi/generator/generate_handler.go +++ b/http/jsonapi/generator/generate_handler.go @@ -27,7 +27,7 @@ const ( pkgSentry = "github.com/getsentry/sentry-go" pkgOAuth2 = "github.com/pace/bricks/http/oauth2" pkgOIDC = "github.com/pace/bricks/http/oidc" - pkgApiKey = "github.com/pace/bricks/http/security/apikey" + pkgAPIKey = "github.com/pace/bricks/http/security/apikey" //nolint:gosec pkgDecimal = "github.com/shopspring/decimal" ) @@ -40,7 +40,7 @@ const ( var noValidation = map[string]string{"valid": "-"} // List of responses that will be handled on the framework level and -// are therefore not handled by the user +// are therefore not handled by the user. var generatorResponseBlacklist = map[string]bool{ "401": true, // if no bearer token is provided "406": true, // if accept header is unacceptable @@ -55,7 +55,7 @@ var generatorResponseBlacklist = map[string]bool{ type routeGeneratorFunc func([]*route, *openapi3.Swagger) error -// BuildHandler generates the request handlers based on gorilla mux +// BuildHandler generates the request handlers based on gorilla mux. func (g *Generator) BuildHandler(schema *openapi3.Swagger) error { paths := schema.Paths // sort by key @@ -63,14 +63,15 @@ func (g *Generator) BuildHandler(schema *openapi3.Swagger) error { for k := range paths { keys = append(keys, k) } + sort.Stable(sort.StringSlice(keys)) var routes []*route for _, pattern := range keys { path := paths[pattern] - err := g.buildPath(pattern, path, &routes, schema.Components.SecuritySchemes) - if err != nil { + + if err := g.buildPath(pattern, path, &routes, schema.Components.SecuritySchemes); err != nil { return err } } @@ -83,8 +84,7 @@ func (g *Generator) BuildHandler(schema *openapi3.Swagger) error { g.buildRouterWithFallbackAsArg, } for _, fn := range funcs { - err := fn(routes, schema) - if err != nil { + if err := fn(routes, schema); err != nil { return err } } @@ -119,8 +119,7 @@ func (g *Generator) buildPath(pattern string, pathItem *openapi3.PathItem, route return err } - err = route.parseURL() - if err != nil { + if err := route.parseURL(); err != nil { return err } @@ -133,14 +132,12 @@ func (g *Generator) buildPath(pattern string, pathItem *openapi3.PathItem, route func (g *Generator) generateRequestResponseTypes(routes []*route, schema *openapi3.Swagger) error { for _, route := range routes { // generate ...ResponseWriter for each route - err := g.generateResponseInterface(route, schema) - if err != nil { + if err := g.generateResponseInterface(route, schema); err != nil { return err } // generate ...Request for each route - err = g.generateRequestStruct(route, schema) - if err != nil { + if err := g.generateRequestStruct(route, schema); err != nil { return err } } @@ -148,15 +145,17 @@ func (g *Generator) generateRequestResponseTypes(routes []*route, schema *openap return nil } -func (g *Generator) generateResponseInterface(route *route, schema *openapi3.Swagger) error { - var methods []jen.Code - methods = append(methods, jen.Qual("net/http", "ResponseWriter")) +func (g *Generator) generateResponseInterface(route *route, _ *openapi3.Swagger) error { + methods := []jen.Code{ + jen.Qual("net/http", "ResponseWriter"), + } // sort by key keys := make([]string, 0, len(route.operation.Responses)) for k := range route.operation.Responses { keys = append(keys, k) } + sort.Stable(sort.StringSlice(keys)) for _, code := range keys { @@ -170,7 +169,7 @@ func (g *Generator) generateResponseInterface(route *route, schema *openapi3.Swa // error responses have an error message parameter codeNum, err := strconv.Atoi(code) if err != nil { - return fmt.Errorf("failed to parse response code %s: %v", code, err) + return fmt.Errorf("failed to parse response code %s: %w", code, err) } // generate method name @@ -203,6 +202,7 @@ func (g *Generator) generateResponseInterface(route *route, schema *openapi3.Swa if err != nil { return err } + method.Params(typeReference) defer func() { // defer to put methods after type @@ -255,12 +255,13 @@ func (g *Generator) generateResponseInterface(route *route, schema *openapi3.Swa return nil } -func (g *Generator) generateRequestStruct(route *route, schema *openapi3.Swagger) error { +func (g *Generator) generateRequestStruct(route *route, _ *openapi3.Swagger) error { body := route.operation.RequestBody - var fields []jen.Code // add http request - fields = append(fields, jen.Id("Request").Op("*").Qual("net/http", "Request").Tag(noValidation)) + fields := []jen.Code{ + jen.Id("Request").Op("*").Qual("net/http", "Request").Tag(noValidation), + } // add request type if body != nil { @@ -279,6 +280,7 @@ func (g *Generator) generateRequestStruct(route *route, schema *openapi3.Swagger for _, param := range route.operation.Parameters { paramName := generateParamName(param) paramStmt := jen.Id(paramName) + tags := make(map[string]string) if param.Value.Required { tags["valid"] = "required" @@ -293,8 +295,7 @@ func (g *Generator) generateRequestStruct(route *route, schema *openapi3.Swagger tg := g.goType(paramStmt, param.Value.Schema.Value, tags) tg.isParam = true - err := tg.invoke() - if err != nil { + if err := tg.invoke(); err != nil { return err } } @@ -309,6 +310,7 @@ func (g *Generator) generateRequestStruct(route *route, schema *openapi3.Swagger g.addGoDoc(route.requestType, "is a standard http.Request extended with the\n"+ "un-marshaled content object") } + g.goSource.Type().Id(route.requestType).Struct(fields...) return nil @@ -335,7 +337,7 @@ func (g *Generator) buildServiceInterface(routes []*route, schema *openapi3.Swag return nil } -func (g *Generator) buildSubServiceInterface(route *route, schema *openapi3.Swagger) error { +func (g *Generator) buildSubServiceInterface(route *route, _ *openapi3.Swagger) error { methods := make([]jen.Code, 0) if route.operation.Description != "" { @@ -343,6 +345,7 @@ func (g *Generator) buildSubServiceInterface(route *route, schema *openapi3.Swag } else { methods = append(methods, jen.Comment(fmt.Sprintf("%s %s", route.serviceFunc, route.operation.Summary))) } + methods = append(methods, jen.Id(route.serviceFunc).Params( jen.Qual("context", "Context"), jen.Id(route.responseType), @@ -360,7 +363,9 @@ func (g *Generator) buildRouter(routes []*route, schema *openapi3.Swagger) error if err != nil { return nil } + g.addGoDoc("Router", "implements: "+schema.Info.Title+"\n\n"+schema.Info.Description) + serviceInterfaceVariable := jen.Id("service").Interface() if hasSecuritySchema(schema) { g.goSource.Func().Id("Router").Params( @@ -369,6 +374,7 @@ func (g *Generator) buildRouter(routes []*route, schema *openapi3.Swagger) error g.goSource.Func().Id("Router").Params( serviceInterfaceVariable).Op("*").Qual(pkgGorillaMux, "Router").Block(routerBody...) } + return nil } @@ -377,7 +383,9 @@ func (g *Generator) buildRouterWithFallbackAsArg(routes []*route, schema *openap if err != nil { return nil } + g.addGoDoc("Router", "implements: "+schema.Info.Title+"\n\n"+schema.Info.Description) + serviceInterfaceVariable := jen.Id("service").Interface() if hasSecuritySchema(schema) { g.goSource.Func().Id("RouterWithFallback").Params( @@ -386,6 +394,7 @@ func (g *Generator) buildRouterWithFallbackAsArg(routes []*route, schema *openap g.goSource.Func().Id("RouterWithFallback").Params( serviceInterfaceVariable, jen.Id("fallback").Qual("net/http", "Handler")).Op("*").Qual(pkgGorillaMux, "Router").Block(routerBody...) } + return nil } @@ -401,15 +410,18 @@ func (g *Generator) buildRouterHelpers(routes []*route, schema *openapi3.Swagger // add all route handlers for i := 0; i < len(sortableRoutes); i++ { route := sortableRoutes[i] + var routeCallParams *jen.Statement if needsSecurity { routeCallParams = jen.List(jen.Id("service"), jen.Id("authBackend")) } else { routeCallParams = jen.List(jen.Id("service")) } + primaryHandler := jen.Id(route.handler).Call(routeCallParams) fallbackHandler := jen.Id(fallbackName) ifElse := make([]jen.Code, 0) + for _, handler := range []jen.Code{primaryHandler, fallbackHandler} { block := jen.Return(handler) ifElse = append(ifElse, block) @@ -431,6 +443,7 @@ func (g *Generator) buildRouterHelpers(routes []*route, schema *openapi3.Swagger } else { callParams = jen.List(jen.Id("service").Id("interface{}"), fallback) } + helper := jen.Func().Id(generateHandlerTypeAssertionHelperName(route.handler)). Params(callParams).Qual("net/http", "Handler").Block(implGuard).Line().Line() @@ -444,7 +457,9 @@ func (g *Generator) buildRouterHelpers(routes []*route, schema *openapi3.Swagger func (g *Generator) buildRouterBodyWithFallback(routes []*route, schema *openapi3.Swagger, fallback jen.Code) ([]jen.Code, error) { needsSecurity := hasSecuritySchema(schema) startInd := 0 - var routeStmts []jen.Code + + var routeStmts []jen.Code //nolint:prealloc + if needsSecurity { startInd++ routeStmts = make([]jen.Code, 2, (len(routes)+2)*len(schema.Servers)+2) @@ -453,7 +468,9 @@ func (g *Generator) buildRouterBodyWithFallback(routes []*route, schema *openapi for name := range schema.Components.SecuritySchemes { names = append(names, name) } + sort.Stable(sort.StringSlice(names)) + caser := cases.Title(language.Und, cases.NoLower) for _, name := range names { routeStmts = append(routeStmts, jen.Id("authBackend").Dot("Init"+caser.String(name)).Call(jen.Id("cfg"+caser.String(name)))) @@ -466,16 +483,20 @@ func (g *Generator) buildRouterBodyWithFallback(routes []*route, schema *openapi // Note: we don't restrict host, scheme and port to ease development pathsIdx := make(map[string]struct{}) + var paths []string + for _, server := range schema.Servers { - serverUrl, err := url.Parse(server.URL) + serverURL, err := url.Parse(server.URL) if err != nil { return nil, err } - if _, ok := pathsIdx[serverUrl.Path]; !ok { - paths = append(paths, serverUrl.Path) + + if _, ok := pathsIdx[serverURL.Path]; !ok { + paths = append(paths, serverURL.Path) } - pathsIdx[serverUrl.Path] = struct{}{} + + pathsIdx[serverURL.Path] = struct{}{} } // but generate subrouters for each server @@ -494,12 +515,14 @@ func (g *Generator) buildRouterBodyWithFallback(routes []*route, schema *openapi // add all route handlers for i := 0; i < len(sortableRoutes); i++ { route := sortableRoutes[i] + var routeCallParams *jen.Statement if needsSecurity { routeCallParams = jen.List(jen.Id("service"), fallback, jen.Id("authBackend")) } else { routeCallParams = jen.List(jen.Id("service"), fallback) } + helper := jen.Id(generateHandlerTypeAssertionHelperName(route.handler)).Call(routeCallParams) routeStmt := jen.Id(subrouterID).Dot("Methods").Call(jen.Lit(route.method)). Dot("Path").Call(jen.Lit(route.url.Path)) @@ -510,6 +533,7 @@ func (g *Generator) buildRouterBodyWithFallback(routes []*route, schema *openapi if len(value) != 1 { panic("query paths can only handle one query parameter with the same name!") } + routeStmt.Dot("Queries").Call(jen.Lit(key), jen.Lit(value[0])) } } @@ -520,7 +544,6 @@ func (g *Generator) buildRouterBodyWithFallback(routes []*route, schema *openapi routeStmt.Dot("Handler").Call(helper) routeStmts = append(routeStmts, routeStmt) - } } @@ -530,7 +553,7 @@ func (g *Generator) buildRouterBodyWithFallback(routes []*route, schema *openapi return routeStmts, nil } -func (g *Generator) buildHandler(method string, op *openapi3.Operation, pattern string, pathItem *openapi3.PathItem, secSchemes map[string]*openapi3.SecuritySchemeRef) (*route, error) { +func (g *Generator) buildHandler(method string, op *openapi3.Operation, pattern string, _ *openapi3.PathItem, secSchemes map[string]*openapi3.SecuritySchemeRef) (*route, error) { needsSecurity := len(secSchemes) > 0 route := &route{ method: strings.ToUpper(method), @@ -545,11 +568,14 @@ func (g *Generator) buildHandler(method string, op *openapi3.Operation, pattern // use OperationID for go function names or generate the name caser := cases.Title(language.Und, cases.NoLower) + oid := caser.String(op.OperationID) if oid == "" { log.Warnf("Note: Avoid automatic method name generation for path (use OperationID): %s", pattern) + oid = generateName(method, op, pattern) } + handler := oid + "Handler" route.handler = handler route.serviceFunc = oid @@ -559,6 +585,7 @@ func (g *Generator) buildHandler(method string, op *openapi3.Operation, pattern // check if handler has request body var requestBody bool + if body := op.RequestBody; body != nil { if mt := body.Value.Content.Get(jsonapiContent); mt != nil { requestBody = true @@ -567,25 +594,31 @@ func (g *Generator) buildHandler(method string, op *openapi3.Operation, pattern // generate handler function gen := g // generator is used less frequent then the jen group, make available with longer name + var auth *jen.Group + if needsSecurity { if op.Security != nil { var err error + auth, err = generateAuthorization(op, secSchemes) if err != nil { return nil, err } } } + g.addGoDoc(handler, fmt.Sprintf(`handles request/response marshaling and validation for %s %s`, method, pattern)) + var params *jen.Statement if needsSecurity { params = jen.List(jen.Id("service").Id(generateSubServiceName(route.handler)), jen.Id("authBackend").Id(authBackendInterface)) } else { params = jen.List(jen.Id("service").Id(generateSubServiceName(route.handler))) } + g.goSource.Func().Id(handler).Params(params).Qual("net/http", "Handler").Block( jen.Return().Qual("net/http", "HandlerFunc").Call( jen.Func().Params( @@ -625,14 +658,17 @@ func (g *Generator) buildHandler(method string, op *openapi3.Operation, pattern // vars in case parameters are given g.Line().Comment("Scan and validate incoming request parameters") + if len(route.operation.Parameters) > 0 { // path parameters need the vars needVars := false + for _, param := range route.operation.Parameters { if param.Value.In == "path" { needVars = true } } + if needVars { g.Id("vars").Op(":=").Qual(pkgGorillaMux, "Vars").Call(jen.Id("r")) } @@ -702,7 +738,9 @@ func (g *Generator) buildHandler(method string, op *openapi3.Operation, pattern // otherwise directly call the service if requestBody { g.Line().Comment("Unmarshal the service request body") + isArray := false + mt := op.RequestBody.Value.Content.Get(jsonapiContent) if mt != nil { data := mt.Schema.Value.Properties["data"] @@ -712,6 +750,7 @@ func (g *Generator) buildHandler(method string, op *openapi3.Operation, pattern } } } + if isArray { typeName := nameFromSchemaRef(mt.Schema.Value.Properties["data"].Value.Items) g.List(jen.Id("ok"), jen.Id("data")).Op(":="). @@ -753,17 +792,21 @@ func (g *Generator) buildHandler(method string, op *openapi3.Operation, pattern func generateAuthorization(op *openapi3.Operation, secSchemes map[string]*openapi3.SecuritySchemeRef) (*jen.Group, error) { req := *op.Security r := &jen.Group{} + if len(req[0]) == 0 { return r, nil } multipleSecSchemes := len(req[0]) > 1 + var err error + if multipleSecSchemes { r, err = generateAuthorizationForMultipleSecSchemas(op, secSchemes) } else { r, err = generateAuthorizationForSingleSecSchema(op, secSchemes) } + if err != nil { return nil, err } @@ -774,10 +817,13 @@ func generateAuthorization(op *openapi3.Operation, secSchemes map[string]*openap func generateAuthorizationForSingleSecSchema(op *openapi3.Operation, schemas map[string]*openapi3.SecuritySchemeRef) (*jen.Group, error) { req := *op.Security r := &jen.Group{} + if len(req[0]) == 0 { return nil, nil } + caser := cases.Title(language.Und, cases.NoLower) + for name, secConfig := range (*op.Security)[0] { securityScheme := schemas[name] switch securityScheme.Value.Type { @@ -791,22 +837,30 @@ func generateAuthorizationForSingleSecSchema(op *openapi3.Operation, schemas map if len(secConfig) > 0 { return nil, fmt.Errorf("security config for api key authorization needs %d values but had: %d", 0, len(secConfig)) } + r.Line().List(jen.Id("ctx"), jen.Id("ok")).Op(":=").Id("authBackend."+authFuncPrefix+caser.String(name)).Call(jen.Id("r"), jen.Id("w")) default: return nil, fmt.Errorf("security Scheme of type %q is not suppported", securityScheme.Value.Type) } } + r.Line().If(jen.Op("!").Id("ok")).Block(jen.Return()) + return r, nil } func generateAuthorizationForMultipleSecSchemas(op *openapi3.Operation, secSchemes map[string]*openapi3.SecuritySchemeRef) (*jen.Group, error) { - var orderedSec [][]string + orderedSec := make([][]string, len((*op.Security)[0])) + i := 0 + // Security Schemes are sorted for a reliable order of the code for name, val := range (*op.Security)[0] { vals := []string{name} - orderedSec = append(orderedSec, append(vals, val...)) + orderedSec[i] = append(vals, val...) + + i++ } + sort.Slice(orderedSec, func(i, j int) bool { return orderedSec[i][0] < orderedSec[j][0] }) @@ -819,11 +873,13 @@ func generateAuthorizationForMultipleSecSchemas(op *openapi3.Operation, secSchem caser := cases.Title(language.Und, cases.NoLower) r.Line().Var().Id("ok").Id("bool") + for _, val := range orderedSec { name := val[0] securityScheme := secSchemes[name] innerBlock := &jen.Group{} innerBlock.Line().List(jen.Id("ctx"), jen.Id("ok")).Op("=").Id("authBackend." + authFuncPrefix + caser.String(name)) + switch securityScheme.Value.Type { case "oauth2", "openIdConnect": if len(val) >= 2 { @@ -835,25 +891,31 @@ func generateAuthorizationForMultipleSecSchemas(op *openapi3.Operation, secSchem if len(val) > 1 { return nil, fmt.Errorf("security config for api key authorization needs %d values but had: %d", 0, len(val)) } + innerBlock.Call(jen.Id("r"), jen.Id("w")) default: return nil, fmt.Errorf("security Scheme of type %q is not suppported", securityScheme.Value.Type) } + innerBlock.Line().If(jen.Op("!").Id("ok")).Block(jen.Return()) r.Line().If(jen.Id("authBackend." + authCanAuthFuncPrefix + caser.String(name))).Call(jen.Id("r")).Block(innerBlock).Else() } + r.Block(last) + return r, nil } var asciiName = regexp.MustCompile("([^a-zA-Z]+)") -func generateName(method string, op *openapi3.Operation, pattern string) string { +func generateName(method string, _ *openapi3.Operation, pattern string) string { name := method parts := strings.Split(asciiName.ReplaceAllString(pattern, "/"), "/") + for _, part := range parts { name += goNameHelper(part) } + return goNameHelper(name) } @@ -862,6 +924,7 @@ func generateMethodName(description string) string { for i := 0; i < len(parts); i++ { parts[i] = goNameHelper(parts[i]) } + return goNameHelper(strings.Join(parts, "")) } diff --git a/http/jsonapi/generator/generate_helper.go b/http/jsonapi/generator/generate_helper.go index 9cf93b1ff..71c3ec15f 100644 --- a/http/jsonapi/generator/generate_helper.go +++ b/http/jsonapi/generator/generate_helper.go @@ -22,7 +22,7 @@ func (g *Generator) addGoDoc(typeName, description string) { } } -func (g *Generator) goType(stmt *jen.Statement, schema *openapi3.Schema, tags map[string]string) *typeGenerator { // nolint: gocyclo +func (g *Generator) goType(stmt *jen.Statement, schema *openapi3.Schema, tags map[string]string) *typeGenerator { return &typeGenerator{ g: g, stmt: stmt, @@ -39,7 +39,7 @@ type typeGenerator struct { isParam bool } -func (g *typeGenerator) invoke() error { // nolint: gocyclo +func (g *typeGenerator) invoke() error { switch g.schema.Type { case "string": switch g.schema.Format { @@ -61,6 +61,7 @@ func (g *typeGenerator) invoke() error { // nolint: gocyclo } case "date": addValidator(g.tags, "time(2006-01-02)") + if g.isParam { g.stmt.Qual("time", "Time") } else { @@ -68,6 +69,7 @@ func (g *typeGenerator) invoke() error { // nolint: gocyclo } case "uuid": addValidator(g.tags, "uuid") + if g.schema.Nullable { g.stmt.Op("*").String() } else { @@ -75,6 +77,7 @@ func (g *typeGenerator) invoke() error { // nolint: gocyclo } case "decimal": addValidator(g.tags, "matches(^(\\d*\\.)?\\d+$)") + if g.isParam { g.stmt.Qual(pkgDecimal, "Decimal") } else { @@ -89,6 +92,7 @@ func (g *typeGenerator) invoke() error { // nolint: gocyclo } case "integer": removeOmitempty(g.tags) + switch g.schema.Format { case "int32": if g.schema.Nullable { @@ -114,6 +118,7 @@ func (g *typeGenerator) invoke() error { // nolint: gocyclo } case "float": removeOmitempty(g.tags) + if g.schema.Nullable { g.stmt.Op("*").Float32() } else { @@ -123,6 +128,7 @@ func (g *typeGenerator) invoke() error { // nolint: gocyclo fallthrough default: removeOmitempty(g.tags) + if g.schema.Nullable { g.stmt.Op("*").Float64() } else { @@ -131,15 +137,16 @@ func (g *typeGenerator) invoke() error { // nolint: gocyclo } case "boolean": removeOmitempty(g.tags) + if g.schema.Nullable { g.stmt.Op("*").Bool() } else { g.stmt.Bool() } - case "array": // nolint: goconst + case "array": removeOmitempty(g.tags) - err := g.g.goType(g.stmt.Index(), g.schema.Items.Value, g.tags).invoke() - if err != nil { + + if err := g.g.goType(g.stmt.Index(), g.schema.Items.Value, g.tags).invoke(); err != nil { return err } default: @@ -156,7 +163,7 @@ func (g *typeGenerator) invoke() error { // nolint: gocyclo // in case the field/value is optional // an empty value needs to be added to the enum validator if hasValidator(g.tags, "optional") { - strs = append(strs, "") + strs = append(strs, "") //nolint:makezero } addValidator(g.tags, fmt.Sprintf("in(%v)", strings.Join(strs, "|"))) @@ -182,6 +189,7 @@ func addValidator(tags map[string]string, validator string) { if cur != "" { validator = tags["valid"] + "," + validator } + tags["valid"] = validator } @@ -190,6 +198,7 @@ func hasValidator(tags map[string]string, validator string) bool { if !ok { return false } + validators := strings.Split(validatorCfg, ",") for _, v := range validators { if strings.HasPrefix(v, validator) { @@ -207,6 +216,7 @@ func goNameHelper(name string) string { name = caser.String(name) name = strings.Replace(name, "Url", "URL", -1) name = idRegex.ReplaceAllString(name, "ID") + return name } @@ -215,5 +225,6 @@ func nameFromSchemaRef(ref *openapi3.SchemaRef) string { if name == "." { return "" } + return name } diff --git a/http/jsonapi/generator/generate_security.go b/http/jsonapi/generator/generate_security.go index b7f4ba045..c5e1db974 100644 --- a/http/jsonapi/generator/generate_security.go +++ b/http/jsonapi/generator/generate_security.go @@ -27,17 +27,25 @@ func (g *Generator) buildSecurityBackendInterface(schema *openapi3.Swagger) erro if !hasSecuritySchema(schema) { return nil } + securitySchemes := schema.Components.SecuritySchemes // r contains the methods for the security interface r := &jen.Group{} // Because the order of the values while iterating over a map is randomized the generated result can only be tested if the keys are sorted - var keys []string + keys := make([]string, len(securitySchemes)) + i := 0 + for k := range securitySchemes { - keys = append(keys, k) + keys[i] = k + + i++ } + sort.Stable(sort.StringSlice(keys)) + hasDuplicatedSecuritySchema := false + for _, pathItem := range schema.Paths { for _, op := range pathItem.Operations() { if op.Security != nil { @@ -47,9 +55,12 @@ func (g *Generator) buildSecurityBackendInterface(schema *openapi3.Swagger) erro } caser := cases.Title(language.Und, cases.NoLower) + for _, name := range keys { value := securitySchemes[name] + r.Line().Id(authFuncPrefix + caser.String(name)) + switch value.Value.Type { case "oauth2": r.Params(jen.Id("r").Id("*http.Request"), jen.Id("w").Id("http.ResponseWriter"), jen.Id("scope").String()).Params(jen.Id("context.Context"), jen.Id("bool")) @@ -59,7 +70,7 @@ func (g *Generator) buildSecurityBackendInterface(schema *openapi3.Swagger) erro r.Line().Id("Init" + caser.String(name)).Params(jen.Id("cfg"+caser.String(name)).Op("*").Qual(pkgOIDC, "Config")) case "apiKey": r.Params(jen.Id("r").Id("*http.Request"), jen.Id("w").Id("http.ResponseWriter")).Params(jen.Id("context.Context"), jen.Id("bool")) - r.Line().Id("Init" + caser.String(name)).Params(jen.Id("cfg"+caser.String(name)).Op("*").Qual(pkgApiKey, "Config")) + r.Line().Id("Init" + caser.String(name)).Params(jen.Id("cfg"+caser.String(name)).Op("*").Qual(pkgAPIKey, "Config")) default: return errors.New("security schema type not supported: " + value.Value.Type) } @@ -70,26 +81,35 @@ func (g *Generator) buildSecurityBackendInterface(schema *openapi3.Swagger) erro } g.goSource.Type().Id(authBackendInterface).Interface(r) + return nil } -// BuildSecurityConfigs creates structs with the config of each security schema +// BuildSecurityConfigs creates structs with the config of each security schema. func (g *Generator) buildSecurityConfigs(schema *openapi3.Swagger) error { if !hasSecuritySchema(schema) { return nil } + securitySchemes := schema.Components.SecuritySchemes // Because the order of the values while iterating over a map is randomized the generated result can only be tested if the keys are sorted - var keys []string + keys := make([]string, len(securitySchemes)) + i := 0 + for k := range securitySchemes { - keys = append(keys, k) + keys[i] = k + + i++ } + sort.Stable(sort.StringSlice(keys)) for _, name := range keys { value := securitySchemes[name] instanceVal := jen.Dict{} + var pkgName string + switch value.Value.Type { case "oauth2": pkgName = pkgOAuth2 @@ -112,40 +132,46 @@ func (g *Generator) buildSecurityConfigs(schema *openapi3.Swagger) error { case "openIdConnect": pkgName = pkgOIDC instanceVal[jen.Id("Description")] = jen.Lit(value.Value.Description) + if e, ok := value.Value.Extensions["openIdConnectUrl"]; ok { var url string if data, ok := e.(json.RawMessage); ok { - err := json.Unmarshal(data, &url) - if err != nil { + if err := json.Unmarshal(data, &url); err != nil { return err } - instanceVal[jen.Id("OpenIdConnectURL")] = jen.Lit(url) + + instanceVal[jen.Id("OpenIDConnectURL")] = jen.Lit(url) } } case "apiKey": - pkgName = pkgApiKey + pkgName = pkgAPIKey instanceVal[jen.Id("Description")] = jen.Lit(value.Value.Description) instanceVal[jen.Id("In")] = jen.Lit(value.Value.In) instanceVal[jen.Id("Name")] = jen.Lit(value.Value.Name) default: return errors.New("security schema type not supported: " + value.Value.Type) } + caser := cases.Title(language.Und, cases.NoLower) g.goSource.Var().Id("cfg"+caser.String(name)).Op("=").Op("&").Qual(pkgName, "Config").Values(instanceVal) } + return nil } -// getValuesFromFlow puts the values from the OAuth Flow in a jen.Dict to generate it +// getValuesFromFlow puts the values from the OAuth Flow in a jen.Dict to generate it. func getValuesFromFlow(flow *openapi3.OAuthFlow) jen.Dict { r := jen.Dict{} r[jen.Id("AuthorizationURL")] = jen.Lit(flow.AuthorizationURL) r[jen.Id("TokenURL")] = jen.Lit(flow.TokenURL) r[jen.Id("RefreshURL")] = jen.Lit(flow.RefreshURL) + scopes := jen.Dict{} for scope, descr := range flow.Scopes { scopes[jen.Lit(scope)] = jen.Lit(descr) } + r[jen.Id("Scopes")] = jen.Map(jen.String()).String().Values(scopes) + return r } diff --git a/http/jsonapi/generator/generate_test.go b/http/jsonapi/generator/generate_test.go index 56b35dd03..54943f41c 100644 --- a/http/jsonapi/generator/generate_test.go +++ b/http/jsonapi/generator/generate_test.go @@ -30,20 +30,28 @@ func TestGenerator(t *testing.T) { } g := Generator{} + result, err := g.BuildSource(testCase.source, filepath.Dir(testCase.pkg), filepath.Base(testCase.pkg)) if err != nil { t.Fatal(err) } + if os.Getenv("PACE_TEST_GENERATOR_WRITE") != "" { + if err := os.MkdirAll("testout", 0o750); err != nil { + t.Fatal(err) + } + f, err := os.Create(fmt.Sprintf("testout/test.%s.out.go", testCase.pkg)) if err != nil { t.Fatal(err) } + _, err = f.WriteString(result) if err != nil { t.Fatal(err) } } + if string(expected[:]) != result { diff := difflib.UnifiedDiff{ A: difflib.SplitLines(string(expected[:])), diff --git a/http/jsonapi/generator/generate_types.go b/http/jsonapi/generator/generate_types.go index 1b0b95ef9..bc8a32f8f 100644 --- a/http/jsonapi/generator/generate_types.go +++ b/http/jsonapi/generator/generate_types.go @@ -17,7 +17,7 @@ const ( pkgJSONAPI = "github.com/pace/bricks/http/jsonapi" ) -// BuildTypes transforms all component schemas into go types +// BuildTypes transforms all component schemas into go types. func (g *Generator) BuildTypes(schema *openapi3.Swagger) error { schemas := schema.Components.Schemas @@ -26,6 +26,7 @@ func (g *Generator) BuildTypes(schema *openapi3.Swagger) error { for k := range schemas { keys = append(keys, k) } + sort.Stable(sort.StringSlice(keys)) for _, name := range keys { @@ -43,8 +44,7 @@ func (g *Generator) BuildTypes(schema *openapi3.Swagger) error { continue } - err := g.buildType(name, t, schemaType, make(map[string]string), true) - if err != nil { + if err := g.buildType(name, t, schemaType, make(map[string]string), true); err != nil { return err } // document type @@ -55,65 +55,77 @@ func (g *Generator) BuildTypes(schema *openapi3.Swagger) error { return nil } -func (g *Generator) buildType(prefix string, stmt *jen.Statement, schema *openapi3.SchemaRef, tags map[string]string, ptr bool) error { // nolint: gocyclo +func (g *Generator) buildType(prefix string, stmt *jen.Statement, schema *openapi3.SchemaRef, tags map[string]string, ptr bool) error { name := nameFromSchemaRef(schema) val := schema.Value switch val.Type { - case "array": // nolint: goconst + case "array": if schema.Ref != "" { // handle references stmt.Id(name) return nil } g.generatedArrayTypes[prefix] = true + return g.buildType(prefix, stmt.Index(), val.Items, tags, ptr) - case "object": // nolint: goconst + case "object": if schema.Ref != "" { // handle references if ptr { stmt.Op("*").Id(name) } else { stmt.Id(name) } + return nil } + if val.AdditionalPropertiesAllowed != nil && *val.AdditionalPropertiesAllowed { if len(val.Properties) > 0 { log.Warnf("%s properties are ignored. Only %s of type map[string]interface{} is generated ", prefix, prefix) } + stmt.Map(jen.String()).Interface() + return nil } + if val.AdditionalProperties != nil { if len(val.Properties) > 0 { log.Warnf("%s properties are ignored. Only %s of type map[string]type is generated ", prefix, prefix) } + stmt.Map(jen.String()) + if val.AdditionalProperties.Ref != "" { stmt.Op("*").Id(nameFromSchemaRef(val.AdditionalProperties)) return nil } + if val.AdditionalProperties.Value != nil { - err := g.goType(stmt, val.AdditionalProperties.Value, make(map[string]string)).invoke() - if err != nil { + if err := g.goType(stmt, val.AdditionalProperties.Value, make(map[string]string)).invoke(); err != nil { return err } } + return nil } if data := val.Properties["data"]; data != nil { if data.Ref != "" { return g.buildType(prefix+"Ref", stmt, data, make(map[string]string), ptr) - } else if data.Value.Type == "array" { // nolint: goconst + } else if data.Value.Type == "array" { item := prefix + "Item" if ptr { stmt.Index().Op("*").Id(item) } else { stmt.Index().Id(item) } + g.addGoDoc(item, data.Value.Description) + itemStmt := g.goSource.Type().Id(item) + return g.structJSONAPI(prefix, itemStmt, data.Value.Items.Value) } else if data.Value.Type == "object" { // This ensures that the code does only treat objects with data properties that // are objects themselves as legitimate JSONAPI struct, otherwise we want them to be treated as simple data objects. @@ -141,11 +153,11 @@ func (g *Generator) buildType(prefix string, stmt *jen.Statement, schema *openap if len(val.AllOf)+len(val.AnyOf)+len(val.OneOf) > 0 { log.Warnf("Can't generate allOf, anyOf and oneOf for type %q", prefix) stmt.Qual("encoding/json", "RawMessage") + return nil } - err := g.goType(stmt, val, tags).invoke() - if err != nil { + if err := g.goType(stmt, val, tags).invoke(); err != nil { return err } } @@ -172,6 +184,7 @@ func (g *Generator) buildTypeStruct(name string, stmt *jen.Statement, schema *op } else { stmt.Id(name) } + return nil } @@ -181,7 +194,7 @@ func (g *Generator) buildTypeStruct(name string, stmt *jen.Statement, schema *op } // references the type from the schema or generates a new type (inline) -// and returns +// and returns. func (g *Generator) generateTypeReference(fallbackName string, schema *openapi3.SchemaRef, noPtr bool) (jen.Code, error) { // handle references if schema.Ref != "" { @@ -191,11 +204,12 @@ func (g *Generator) generateTypeReference(fallbackName string, schema *openapi3. // in case the type referenced is defined already directly reference it sv := schema.Value - if sv.Type == "object" && sv.Properties["data"] != nil && sv.Properties["data"].Ref != "" { // nolint: goconst + if sv.Type == "object" && sv.Properties["data"] != nil && sv.Properties["data"].Ref != "" { id := nameFromSchemaRef(schema.Value.Properties["data"]) if g.generatedArrayTypes[id] { return jen.Id(id), nil } + if noPtr { return jen.Id(id), nil } @@ -207,11 +221,12 @@ func (g *Generator) generateTypeReference(fallbackName string, schema *openapi3. t, ok := g.newType(fallbackName) if ok { g.addGoDoc(fallbackName, schema.Value.Description) - err := g.buildType(fallbackName, g.goSource.Add(t), schema, make(map[string]string), true) - if err != nil { + + if err := g.buildType(fallbackName, g.goSource.Add(t), schema, make(map[string]string), true); err != nil { return nil, err } } + if noPtr { return jen.Id(fallbackName), nil } @@ -219,7 +234,7 @@ func (g *Generator) generateTypeReference(fallbackName string, schema *openapi3. return jen.Op("*").Id(fallbackName), nil } -func (g *Generator) structJSONAPI(prefix string, stmt *jen.Statement, schema *openapi3.Schema) error { // nolint: gocyclo +func (g *Generator) structJSONAPI(prefix string, stmt *jen.Statement, schema *openapi3.Schema) error { var fields []jen.Code propID := schema.Properties["id"] @@ -234,6 +249,7 @@ func (g *Generator) structJSONAPI(prefix string, stmt *jen.Statement, schema *op if err != nil { return err } + fields = append(fields, id) // add attributes @@ -242,6 +258,7 @@ func (g *Generator) structJSONAPI(prefix string, stmt *jen.Statement, schema *op if err != nil { return err } + fields = append(fields, attrFields...) } @@ -249,10 +266,11 @@ func (g *Generator) structJSONAPI(prefix string, stmt *jen.Statement, schema *op links := schema.Properties["links"] if links != nil { linksAttr := jen.Id("Links") - err := g.buildTypeStruct(prefix+"Links", linksAttr, links.Value, true) - if err != nil { + + if err := g.buildTypeStruct(prefix+"Links", linksAttr, links.Value, true); err != nil { return err } + fields = append(fields, linksAttr) } @@ -261,12 +279,13 @@ func (g *Generator) structJSONAPI(prefix string, stmt *jen.Statement, schema *op if meta != nil { metaAttr := jen.Id("Meta") defer func() { - err := g.buildTypeStruct(prefix+"Meta", metaAttr, meta.Value, true) - if err != nil { + if err := g.buildTypeStruct(prefix+"Meta", metaAttr, meta.Value, true); err != nil { log.Fatal(err) } + metaAttr.Comment("Resource meta data (json:api meta)") }() + fields = append(fields, metaAttr) } @@ -276,6 +295,7 @@ func (g *Generator) structJSONAPI(prefix string, stmt *jen.Statement, schema *op if err != nil { return err } + fields = append(fields, relFields...) } @@ -283,10 +303,7 @@ func (g *Generator) structJSONAPI(prefix string, stmt *jen.Statement, schema *op // generate meta function if any if meta != nil { - err := g.generateJSONAPIMeta(prefix, stmt, meta.Value) - if err != nil { - return err - } + g.generateJSONAPIMeta(prefix, stmt, meta.Value) } return nil @@ -295,30 +312,35 @@ func (g *Generator) structJSONAPI(prefix string, stmt *jen.Statement, schema *op func (g *Generator) generateAttrField(prefix, name string, schema *openapi3.SchemaRef, tags map[string]string) (*jen.Statement, error) { field := jen.Id(goNameHelper(name)) - err := g.buildType(prefix+goNameHelper(name), field, schema, tags, false) - if err != nil { + if err := g.buildType(prefix+goNameHelper(name), field, schema, tags, false); err != nil { return nil, err } + field.Tag(tags) + if schema.Ref == "" { g.commentOrExample(field, schema.Value) } + return field, nil } -func (g *Generator) generateStructFields(prefix string, schema *openapi3.Schema, jsonAPIObject bool) ([]jen.Code, error) { +func (g *Generator) generateStructFields(prefix string, schema *openapi3.Schema, _ bool) ([]jen.Code, error) { // sort by key keys := make([]string, 0, len(schema.Properties)) for k := range schema.Properties { keys = append(keys, k) } + sort.Stable(sort.StringSlice(keys)) - var fields []jen.Code + fields := make([]jen.Code, 0) + for _, attrName := range keys { attrSchema := schema.Properties[attrName] tags := make(map[string]string) addJSONAPITags(tags, "attr", attrName) + if attrSchema.Value.AdditionalPropertiesAllowed != nil && *attrSchema.Value.AdditionalPropertiesAllowed || attrSchema.Value.AdditionalProperties != nil { @@ -332,20 +354,24 @@ func (g *Generator) generateStructFields(prefix string, schema *openapi3.Schema, if err != nil { return nil, err } + fields = append(fields, field) } + return fields, nil } -func (g *Generator) generateStructRelationships(prefix string, schema *openapi3.Schema, jsonAPI bool) ([]jen.Code, error) { +func (g *Generator) generateStructRelationships(prefix string, schema *openapi3.Schema, _ bool) ([]jen.Code, error) { // sort by key keys := make([]string, 0, len(schema.Properties)) for k := range schema.Properties { keys = append(keys, k) } + sort.Stable(sort.StringSlice(keys)) - var relationships []jen.Code + relationships := make([]jen.Code, 0) + for _, relName := range keys { relSchema := schema.Properties[relName] tags := make(map[string]string) @@ -363,22 +389,31 @@ func (g *Generator) generateStructRelationships(prefix string, schema *openapi3. switch data.Value.Type { // case array = one-to-many - case "array": // nolint: goconst - name := data.Value.Items.Value.Properties["type"].Value.Enum[0].(string) + case "array": + name, ok := data.Value.Items.Value.Properties["type"].Value.Enum[0].(string) + if !ok { + return nil, fmt.Errorf("expected as string got %T", data.Value.Items.Value.Properties["type"].Value.Enum[0]) + } + rel.Index().Op("*").Id(goNameHelper(name)).Tag(tags) // case object = belongs-to - case "object": // nolint: goconst - name := data.Value.Properties["type"].Value.Enum[0].(string) + case "object": + name, ok := data.Value.Properties["type"].Value.Enum[0].(string) + if !ok { + return nil, fmt.Errorf("expected as string got %T", data.Value.Properties["type"].Value.Enum[0]) + } + rel.Op("*").Id(goNameHelper(name)).Tag(tags) } relationships = append(relationships, rel) } + return relationships, nil } -// generateJSONAPIMeta generates a function that implements JSONAPIMeta -func (g *Generator) generateJSONAPIMeta(typeName string, stmt *jen.Statement, schema *openapi3.Schema) error { +// generateJSONAPIMeta generates a function that implements JSONAPIMeta. +func (g *Generator) generateJSONAPIMeta(typeName string, stmt *jen.Statement, schema *openapi3.Schema) { stmt.Line().Comment("JSONAPIMeta implements the meta data API for json:api").Line(). Func().Params(jen.Id("r").Op("*").Id(typeName)).Id("JSONAPIMeta").Params().Op("*").Qual(pkgJSONAPI, "Meta").BlockFunc( func(g *jen.Group) { @@ -391,6 +426,7 @@ func (g *Generator) generateJSONAPIMeta(typeName string, stmt *jen.Statement, sc for k := range schema.Properties { keys = append(keys, k) } + sort.Stable(sort.StringSlice(keys)) for _, attrName := range keys { @@ -399,8 +435,6 @@ func (g *Generator) generateJSONAPIMeta(typeName string, stmt *jen.Statement, sc g.Return(jen.Op("&").Id("meta")) }) - - return nil } func (g *Generator) generateIDField(idType, objectType *openapi3.Schema) (*jen.Statement, error) { @@ -408,29 +442,34 @@ func (g *Generator) generateIDField(idType, objectType *openapi3.Schema) (*jen.S tags := map[string]string{ "jsonapi": fmt.Sprintf("primary,%s,omitempty", objectType.Enum[0]), } - err := g.goType(id, idType, tags).invoke() - if err != nil { + + if err := g.goType(id, idType, tags).invoke(); err != nil { return nil, err } + addValidator(tags, "optional") id.Tag(tags) g.commentOrExample(id, idType) + return id, nil } // newType generates a new type only if it was not generated yet. -// returns nil, false if type already exists +// returns nil, false if type already exists. func (g *Generator) newType(name string) (*jen.Statement, bool) { if g.generatedTypes[name] { return nil, false } + g.generatedTypes[name] = true + return jen.Type().Id(name), true } func addRequiredOptionalTag(tags map[string]string, name string, schema *openapi3.Schema) { // check if field is required isRequired := false + for _, required := range schema.Required { if required == name { isRequired = true @@ -455,6 +494,7 @@ func removeOmitempty(tags map[string]string) { if v, ok := tags["jsonapi"]; ok { tags["jsonapi"] = strings.ReplaceAll(v, ",omitempty", "") } + if v, ok := tags["json"]; ok { tags["json"] = strings.ReplaceAll(v, ",omitempty", "") } diff --git a/http/jsonapi/generator/internal/fueling/fueling_test.go b/http/jsonapi/generator/internal/fueling/fueling_test.go index d97d46708..4879b5924 100644 --- a/http/jsonapi/generator/internal/fueling/fueling_test.go +++ b/http/jsonapi/generator/internal/fueling/fueling_test.go @@ -3,6 +3,7 @@ package fueling import ( "context" "io" + "net/http" "net/http/httptest" "strings" "testing" @@ -36,7 +37,7 @@ func (t *testService) WaitOnPumpStatusChange(context.Context, WaitOnPumpStatusCh func TestErrorReporting(t *testing.T) { r := Router(&testService{t}) rec := httptest.NewRecorder() - req := httptest.NewRequest("POST", "/fueling/beta/gas-stations/d7101f72-a672-453c-9d36-d5809ef0ded6/approaching", strings.NewReader(`{ + req := httptest.NewRequest(http.MethodPost, "/fueling/beta/gas-stations/d7101f72-a672-453c-9d36-d5809ef0ded6/approaching", strings.NewReader(`{ "data": { "type": "approaching", "id": "c3f037ea-492e-4033-9b4b-4efc7beca16c", @@ -52,9 +53,14 @@ func TestErrorReporting(t *testing.T) { r.ServeHTTP(rec, req) resp := rec.Result() - defer resp.Body.Close() + + defer func() { + err := resp.Body.Close() + assert.NoError(t, err) + }() + b, _ := io.ReadAll(resp.Body) - require.Equalf(t, 422, resp.StatusCode, "expected 422 got: %s", string(b)) - assert.Contains(t, string(b), `can't parse content: got value \"47.8\" expected type float32: Invalid type provided`) + require.Equalf(t, http.StatusUnprocessableEntity, resp.StatusCode, "expected 422 got: %s", string(b)) + assert.Contains(t, string(b), `can't parse content: got value \"47.8\" expected type float32: invalid type provided`) } diff --git a/http/jsonapi/generator/internal/pay/open-api_test.go b/http/jsonapi/generator/internal/pay/open-api_test.go index 8f0a30bfe..058acf6da 100644 --- a/http/jsonapi/generator/internal/pay/open-api_test.go +++ b/http/jsonapi/generator/internal/pay/open-api_test.go @@ -144,7 +144,7 @@ var cfgOAuth2 = &oauth2.Config{ } var cfgOpenID = &oidc.Config{ Description: "", - OpenIdConnectURL: "https://example.com/.well-known/openid-configuration", + OpenIDConnectURL: "https://example.com/.well-known/openid-configuration", } var cfgProfileKey = &apikey.Config{ Description: "prefix with \"Bearer \"", diff --git a/http/jsonapi/generator/internal/pay/pay_test.go b/http/jsonapi/generator/internal/pay/pay_test.go index 19aa98c9a..35e8ece40 100644 --- a/http/jsonapi/generator/internal/pay/pay_test.go +++ b/http/jsonapi/generator/internal/pay/pay_test.go @@ -10,6 +10,7 @@ import ( "testing" "github.com/shopspring/decimal" + "github.com/stretchr/testify/assert" "github.com/pace/bricks/http/jsonapi" "github.com/pace/bricks/http/jsonapi/runtime" @@ -31,6 +32,7 @@ func (s *testService) CreatePaymentMethodSEPA(ctx context.Context, w CreatePayme if str := "Jon"; r.Content.FirstName != str { s.t.Errorf("expected FirstName to be %q, got %q", str, r.Content.FirstName) } + if str := "Haid-und-Neu-Str."; r.Content.Address.Street != str { s.t.Errorf("expected Address.Street to be %q, got %q", str, r.Content.Address.Street) } @@ -76,6 +78,7 @@ func (s *testService) ProcessPayment(ctx context.Context, w ProcessPaymentRespon if r.Content.PriceIncludingVAT.String() != "69.34" { s.t.Errorf(`expected priceIncludingVAT "69.34", got %q`, r.Content.PriceIncludingVAT) } + amount := decimal.RequireFromString("11.07") rate := decimal.RequireFromString("19.0") priceWithVat := decimal.RequireFromString("69.34") @@ -139,7 +142,7 @@ func (s testAuthBackend) InitProfileKey(cfgProfileKey *apikey.Config) { func TestHandler(t *testing.T) { r := Router(&testService{t}, &testAuthBackend{}) rec := httptest.NewRecorder() - req := httptest.NewRequest("POST", "/pay/beta/payment-methods/sepa-direct-debit", strings.NewReader(`{ + req := httptest.NewRequest(http.MethodPost, "/pay/beta/payment-methods/sepa-direct-debit", strings.NewReader(`{ "data": { "id": "2a1319c3-c136-495d-b59a-47b3246d08af", "type": "paymentMethod", @@ -164,14 +167,20 @@ func TestHandler(t *testing.T) { r.ServeHTTP(rec, req) resp := rec.Result() - defer resp.Body.Close() - if resp.StatusCode != 201 { + defer func() { + err := resp.Body.Close() + assert.NoError(t, err) + }() + + if resp.StatusCode != http.StatusCreated { t.Errorf("expected OK got: %d", resp.StatusCode) + b, err := io.ReadAll(resp.Body) if err != nil { t.Fatal(err) } + t.Error(string(b[:])) } @@ -188,7 +197,7 @@ func TestHandler(t *testing.T) { func TestHandlerDecimal(t *testing.T) { r := Router(&testService{t}, &testAuthBackend{}) rec := httptest.NewRecorder() - req := httptest.NewRequest("POST", "/pay/beta/transaction/1337.42?queryDecimal=123.456", strings.NewReader(`{ + req := httptest.NewRequest(http.MethodPost, "/pay/beta/transaction/1337.42?queryDecimal=123.456", strings.NewReader(`{ "data": { "id": "5d3607f4-7855-4bfc-b926-1e662c225f06", "type": "transaction", @@ -211,14 +220,20 @@ func TestHandlerDecimal(t *testing.T) { r.ServeHTTP(rec, req) resp := rec.Result() - defer resp.Body.Close() - if resp.StatusCode != 201 { + defer func() { + err := resp.Body.Close() + assert.NoError(t, err) + }() + + if resp.StatusCode != http.StatusCreated { t.Errorf("expected OK got: %d", resp.StatusCode) + b, err := io.ReadAll(resp.Body) if err != nil { t.Fatal(err) } + t.Error(string(b[:])) } @@ -242,21 +257,27 @@ func assertDecimal(t *testing.T, got, want decimal.Decimal) { func TestHandlerPanic(t *testing.T) { r := Router(&testService{t}, &testAuthBackend{}) rec := httptest.NewRecorder() - req := httptest.NewRequest("GET", "/pay/beta/payment-methods?include=paymentToken", nil) + req := httptest.NewRequest(http.MethodGet, "/pay/beta/payment-methods?include=paymentToken", nil) req.Header.Set("Accept", runtime.JSONAPIContentType) req.Header.Set("Content-Type", runtime.JSONAPIContentType) log.Handler()(r).ServeHTTP(rec, req) resp := rec.Result() - defer resp.Body.Close() + + defer func() { + err := resp.Body.Close() + assert.NoError(t, err) + }() if resp.StatusCode != http.StatusInternalServerError { t.Errorf("expected 500 got: %d", resp.StatusCode) + b, err := io.ReadAll(resp.Body) if err != nil { t.Fatal(err) } + t.Error(string(b[:])) } } @@ -264,21 +285,27 @@ func TestHandlerPanic(t *testing.T) { func TestHandlerError(t *testing.T) { r := Router(&testService{t}, &testAuthBackend{}) rec := httptest.NewRecorder() - req := httptest.NewRequest("GET", "/pay/beta/payment-methods", nil) + req := httptest.NewRequest(http.MethodGet, "/pay/beta/payment-methods", nil) req.Header.Set("Accept", runtime.JSONAPIContentType) req.Header.Set("Content-Type", runtime.JSONAPIContentType) log.Handler()(r).ServeHTTP(rec, req) resp := rec.Result() - defer resp.Body.Close() + + defer func() { + err := resp.Body.Close() + assert.NoError(t, err) + }() if resp.StatusCode != http.StatusInternalServerError { t.Errorf("expected 500 got: %d", resp.StatusCode) + b, err := io.ReadAll(resp.Body) if err != nil { t.Fatal(err) } + t.Error(string(b[:])) } } diff --git a/http/jsonapi/generator/internal/poi/open-api_test.go b/http/jsonapi/generator/internal/poi/open-api_test.go index 4585aad94..4d928aa95 100644 --- a/http/jsonapi/generator/internal/poi/open-api_test.go +++ b/http/jsonapi/generator/internal/poi/open-api_test.go @@ -417,7 +417,7 @@ var cfgOAuth2 = &oauth2.Config{ } var cfgOIDC = &oidc.Config{ Description: "", - OpenIdConnectURL: "https://id.pace.cloud/auth/realms/pace/.well-known/openid-configuration", + OpenIDConnectURL: "https://id.pace.cloud/auth/realms/pace/.well-known/openid-configuration", } /* diff --git a/http/jsonapi/generator/internal/poi/poi_test.go b/http/jsonapi/generator/internal/poi/poi_test.go index da132ec1d..4c0f9644c 100644 --- a/http/jsonapi/generator/internal/poi/poi_test.go +++ b/http/jsonapi/generator/internal/poi/poi_test.go @@ -35,9 +35,11 @@ func (s *testService) CheckForPaceApp(ctx context.Context, w CheckForPaceAppResp if r.ParamFilterLatitude != 41.859194 { s.t.Errorf("expected ParamLatitude to be %f, got: %f", 41.859194, r.ParamFilterLatitude) } + if r.ParamFilterLongitude != -87.646984 { s.t.Errorf("expected ParamLongitude to be %f, got: %f", -87.646984, r.ParamFilterLatitude) } + if r.ParamFilterAppType != "fueling" { s.t.Errorf("expected ParamAppType to be %q, got: %q", "fueling", r.ParamFilterAppType) } @@ -225,7 +227,7 @@ func (s testAuthBackend) InitOIDC(cfgOIDC *oidc.Config) {} func TestHandler(t *testing.T) { r := Router(&testService{t}, &testAuthBackend{}) rec := httptest.NewRecorder() - req := httptest.NewRequest("GET", "/poi/beta/apps/query?"+ + req := httptest.NewRequest(http.MethodGet, "/poi/beta/apps/query?"+ "filter[latitude]=41.859194&filter[longitude]=-87.646984&filter[appType]=fueling", nil) req.Header.Set("Accept", runtime.JSONAPIContentType) req.Header.Set("Content-Type", runtime.JSONAPIContentType) @@ -233,41 +235,50 @@ func TestHandler(t *testing.T) { r.ServeHTTP(rec, req) resp := rec.Result() - defer resp.Body.Close() + defer func() { + _ = resp.Body.Close() + }() - if resp.StatusCode != 200 { + if resp.StatusCode != http.StatusOK { t.Errorf("expected OK got: %d", resp.StatusCode) t.Error(rec.Body.String()) + return } var data struct { Data []map[string]interface{} `json:"data"` } - err := json.NewDecoder(resp.Body).Decode(&data) - if err != nil { + + if err := json.NewDecoder(resp.Body).Decode(&data); err != nil { t.Fatal(err) return } + if len(data.Data) != 10 { t.Error("Expected 10 apps") return } + if data.Data[0]["type"] != "locationBasedAppWithRefs" { t.Error("Expected type locationBasedAppWithRefs") return } + attributes, ok := data.Data[0]["attributes"].(map[string]interface{}) if !ok { t.Error("Expected attributes do be present") return } + if attributes["androidInstantAppUrl"] != "https://foobar.com" { t.Error(`Expected androidInstantAppUrl to be "https://foobar.com"`) } + if attributes["title"] != "some app" { t.Error(`Expected androidInstantAppUrl to be "some app"`) } + if attributes["appType"] != "some type" { t.Error(`Expected androidInstantAppUrl to be "some type"`) } @@ -276,48 +287,57 @@ func TestHandler(t *testing.T) { func TestHandlerWithTimeInQuery(t *testing.T) { r := Router(&testService{t}, &testAuthBackend{}) rec := httptest.NewRecorder() - req := httptest.NewRequest("GET", "/poi/beta/apps?filter[since]=2020-05-06T12%3A22%3A54%2E000888456", nil) + req := httptest.NewRequest(http.MethodGet, "/poi/beta/apps?filter[since]=2020-05-06T12%3A22%3A54%2E000888456", nil) req.Header.Set("Accept", runtime.JSONAPIContentType) req.Header.Set("Content-Type", runtime.JSONAPIContentType) r.ServeHTTP(rec, req) resp := rec.Result() - defer resp.Body.Close() + defer func() { + _ = resp.Body.Close() + }() - if resp.StatusCode != 200 { + if resp.StatusCode != http.StatusOK { t.Errorf("expected OK got: %d", resp.StatusCode) t.Error(rec.Body.String()) + return } var data struct { Data []map[string]interface{} `json:"data"` } - err := json.NewDecoder(resp.Body).Decode(&data) - if err != nil { + + if err := json.NewDecoder(resp.Body).Decode(&data); err != nil { t.Fatal(err) return } + if len(data.Data) != 10 { t.Error("Expected 10 apps") return } + if data.Data[0]["type"] != "locationBasedApp" { t.Error("Expected type locationBasedApp") return } + attributes, ok := data.Data[0]["attributes"].(map[string]interface{}) if !ok { t.Error("Expected attributes do be present") return } + if attributes["androidInstantAppUrl"] != "https://foobar.com" { t.Error(`Expected androidInstantAppUrl to be "https://foobar.com"`) } + if attributes["title"] != "some app" { t.Error(`Expected androidInstantAppUrl to be "some app"`) } + if attributes["appType"] != "some type" { t.Error(`Expected androidInstantAppUrl to be "some type"`) } @@ -326,7 +346,7 @@ func TestHandlerWithTimeInQuery(t *testing.T) { func TestCreatePolicyHandler(t *testing.T) { r := Router(&testService{t}, &testAuthBackend{}) rec := httptest.NewRecorder() - req := httptest.NewRequest("POST", "/poi/beta/policies", strings.NewReader(`{ + req := httptest.NewRequest(http.MethodPost, "/poi/beta/policies", strings.NewReader(`{ "data": { "id": "f106ac99-213c-4cf7-8c1b-1e841516026b", "type": "policies", @@ -355,11 +375,14 @@ func TestCreatePolicyHandler(t *testing.T) { r.ServeHTTP(rec, req) resp := rec.Result() - defer resp.Body.Close() + defer func() { + _ = resp.Body.Close() + }() - if resp.StatusCode != 200 { + if resp.StatusCode != http.StatusOK { t.Errorf("expected OK got: %d", resp.StatusCode) t.Error(rec.Body.String()) + return } } diff --git a/http/jsonapi/generator/internal/securitytest/security_test.go b/http/jsonapi/generator/internal/securitytest/security_test.go index 5b9d3cfa9..36920dc00 100644 --- a/http/jsonapi/generator/internal/securitytest/security_test.go +++ b/http/jsonapi/generator/internal/securitytest/security_test.go @@ -8,9 +8,11 @@ import ( "net/http/httptest" "testing" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/pace/bricks/http/oauth2" "github.com/pace/bricks/http/security/apikey" - "github.com/stretchr/testify/require" ) type testService struct{} @@ -57,49 +59,69 @@ func TestSecurityBothAuthenticationMethods(t *testing.T) { // oauth2 OK, profileKey OK, canAuth: both w := httptest.NewRecorder() - r := httptest.NewRequest("GET", "http://test.de/pay/beta/test", nil) + r := httptest.NewRequest(http.MethodGet, "http://test.de/pay/beta/test", nil) router.ServeHTTP(w, r) + result := w.Result() require.Equal(t, http.StatusOK, result.StatusCode) + err := result.Body.Close() + assert.NoError(t, err) + // oauth2 ok, profileKey OK, canAuth: none authBackend.canAuthProfileKey = false authBackend.canAuthOauth = false w = httptest.NewRecorder() - r = httptest.NewRequest("GET", "http://test.de/pay/beta/test", nil) + r = httptest.NewRequest(http.MethodGet, "http://test.de/pay/beta/test", nil) router.ServeHTTP(w, r) + result = w.Result() require.Equal(t, http.StatusUnauthorized, result.StatusCode) + err = result.Body.Close() + assert.NoError(t, err) + // oauth2 400, profileKey OK, canAuth = oauth2 authBackend.canAuthProfileKey = false authBackend.canAuthOauth = true w = httptest.NewRecorder() authBackend.oauth2Code = http.StatusBadRequest - r = httptest.NewRequest("GET", "http://test.de/pay/beta/test", nil) + r = httptest.NewRequest(http.MethodGet, "http://test.de/pay/beta/test", nil) router.ServeHTTP(w, r) + result = w.Result() require.Equal(t, http.StatusBadRequest, result.StatusCode) + err = result.Body.Close() + assert.NoError(t, err) + // oauth2 400, profileKey OK, canAuth = profileKey authBackend.canAuthProfileKey = true authBackend.canAuthOauth = false w = httptest.NewRecorder() authBackend.oauth2Code = http.StatusBadRequest - r = httptest.NewRequest("GET", "http://test.de/pay/beta/test", nil) + r = httptest.NewRequest(http.MethodGet, "http://test.de/pay/beta/test", nil) router.ServeHTTP(w, r) + result = w.Result() require.Equal(t, http.StatusOK, result.StatusCode) + err = result.Body.Close() + assert.NoError(t, err) + // oauth2 400, profileKey 500, canAuth = both w = httptest.NewRecorder() authBackend.profileKeyCode = http.StatusInternalServerError authBackend.oauth2Code = http.StatusBadRequest authBackend.canAuthProfileKey = true authBackend.canAuthOauth = true - r = httptest.NewRequest("GET", "http://test.de/pay/beta/test", nil) + r = httptest.NewRequest(http.MethodGet, "http://test.de/pay/beta/test", nil) router.ServeHTTP(w, r) + result = w.Result() // Alphabetic order => get the error of the alphabetic first security scheme require.Equal(t, http.StatusBadRequest, result.StatusCode) + + err = result.Body.Close() + assert.NoError(t, err) } diff --git a/http/jsonapi/generator/route.go b/http/jsonapi/generator/route.go index 0b7b32935..11b9a49a1 100644 --- a/http/jsonapi/generator/route.go +++ b/http/jsonapi/generator/route.go @@ -24,7 +24,9 @@ func (r *route) parseURL() (err error) { if err != nil { return } + r.queryValues = r.url.Query() // cache query values + return } @@ -45,9 +47,11 @@ func (l *sortableRouteList) Less(i, j int) bool { if a, b := pathLen(elemI.url.Path), pathLen(elemJ.url.Path); a != b { return a > b } + if a, b := strings.Count(elemJ.url.Path, "{"), strings.Count(elemI.url.Path, "{"); a != b { return a > b } + return len(elemI.queryValues) > len(elemJ.queryValues) } diff --git a/http/jsonapi/generator/route_test.go b/http/jsonapi/generator/route_test.go index cec8bc409..2eb7e53e6 100644 --- a/http/jsonapi/generator/route_test.go +++ b/http/jsonapi/generator/route_test.go @@ -35,16 +35,21 @@ func TestSortableRouteList(t *testing.T) { "/beta/receipts/{transactionID}.{fileFormat}", } list := make(sortableRouteList, len(paths)) + for i, path := range paths { route := &route{pattern: path} require.NoError(t, route.parseURL()) + list[i] = route } + sort.Stable(&list) + actual := make([]string, len(paths)) for i, route := range list { actual[i] = route.pattern } + assert.Equal(t, []string{ "/beta/payment-method-kinds/applepay/authorize", "/beta/payment-methods/{paymentMethodId}/notification", diff --git a/http/jsonapi/middleware/error_middleware.go b/http/jsonapi/middleware/error_middleware.go index 2822b40a5..3dab52446 100644 --- a/http/jsonapi/middleware/error_middleware.go +++ b/http/jsonapi/middleware/error_middleware.go @@ -22,15 +22,19 @@ func (e *errorMiddleware) Write(b []byte) (int, error) { log.Req(e.req).Warn().Msgf("Error already sent, ignoring: %q", string(b)) return 0, nil } - repliesJsonApi := e.Header().Get("Content-Type") == runtime.JSONAPIContentType - requestsJsonApi := e.req.Header.Get("Accept") == runtime.JSONAPIContentType - if e.statusCode >= 400 && requestsJsonApi && !repliesJsonApi { + + repliesJSONAPI := e.Header().Get("Content-Type") == runtime.JSONAPIContentType + requestsJSONAPI := e.req.Header.Get("Accept") == runtime.JSONAPIContentType + + if e.statusCode >= 400 && requestsJSONAPI && !repliesJSONAPI { if e.hasBytes { log.Req(e.req).Warn().Msgf("Body already contains data from previous writes: ignoring: %q", string(b)) return 0, nil } + e.hasErr = true runtime.WriteError(e.ResponseWriter, e.statusCode, errors.New(strings.Trim(string(b), "\n"))) + return 0, nil } @@ -38,6 +42,7 @@ func (e *errorMiddleware) Write(b []byte) (int, error) { if err == nil && n > 0 { e.hasBytes = true } + return n, err } @@ -46,9 +51,9 @@ func (e *errorMiddleware) WriteHeader(code int) { e.ResponseWriter.WriteHeader(code) } -// ErrorMiddleware is a middleware that wraps http.ResponseWriter +// Error is a middleware that wraps http.ResponseWriter // such that it forces responses with status codes 4xx/5xx to have -// Content-Type: application/vnd.api+json +// Content-Type: application/vnd.api+json. func Error(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { next.ServeHTTP(&errorMiddleware{ResponseWriter: w, req: r}, r) diff --git a/http/jsonapi/middleware/error_middleware_test.go b/http/jsonapi/middleware/error_middleware_test.go index fb0c80a26..84a3297b0 100644 --- a/http/jsonapi/middleware/error_middleware_test.go +++ b/http/jsonapi/middleware/error_middleware_test.go @@ -8,31 +8,38 @@ import ( "testing" "github.com/gorilla/mux" + "github.com/stretchr/testify/assert" + "github.com/pace/bricks/http/jsonapi/runtime" ) const payload = "dummy response data" func TestErrorMiddleware(t *testing.T) { - for _, statusCode := range []int{200, 201, 400, 402, 500, 503} { + for _, statusCode := range []int{http.StatusOK, http.StatusCreated, http.StatusBadRequest, 402, 500, 503} { for _, responseContentType := range []string{"text/plain", "text/html", runtime.JSONAPIContentType} { r := mux.NewRouter() r.HandleFunc("/foo", func(w http.ResponseWriter, r *http.Request) { w.Header().Set("Content-Type", responseContentType) w.WriteHeader(statusCode) _, _ = io.WriteString(w, payload) - }).Methods("GET") + }).Methods(http.MethodGet) r.Use(Error) rec := httptest.NewRecorder() - req := httptest.NewRequest("GET", "/foo", nil) + req := httptest.NewRequest(http.MethodGet, "/foo", nil) req.Header.Set("Accept", runtime.JSONAPIContentType) r.ServeHTTP(rec, req) resp := rec.Result() b, err := io.ReadAll(resp.Body) - resp.Body.Close() + + defer func() { + err := resp.Body.Close() + assert.NoError(t, err) + }() + if err != nil { t.Fatal(err) } @@ -40,6 +47,7 @@ func TestErrorMiddleware(t *testing.T) { if statusCode != resp.StatusCode { t.Fatalf("status codes differ: expected %v, got %v", statusCode, resp.StatusCode) } + if resp.StatusCode < 400 || responseContentType == runtime.JSONAPIContentType { if payload != string(b) { t.Fatalf("payloads differ: expected %v, got %v", payload, string(b)) @@ -49,13 +57,14 @@ func TestErrorMiddleware(t *testing.T) { List runtime.Errors `json:"errors"` } - err := json.Unmarshal(b, &e) - if err != nil { + if err := json.Unmarshal(b, &e); err != nil { t.Fatal(err) } + if len(e.List) != 1 { t.Fatalf("expected only one record, got %v", len(e.List)) } + if payload != e.List[0].Title { t.Fatalf("error titles differ: expected %v, got %v", payload, e.List[0].Title) } @@ -67,7 +76,7 @@ func TestErrorMiddleware(t *testing.T) { func TestJsonApiErrorMiddlewareMultipleErrorWrite(t *testing.T) { r := mux.NewRouter() r.HandleFunc("/foo", func(w http.ResponseWriter, r *http.Request) { - w.WriteHeader(400) + w.WriteHeader(http.StatusBadRequest) w.Header().Set("Content-Type", "text/html") if _, err := io.WriteString(w, payload); err != nil { t.Fatal(err) @@ -81,28 +90,38 @@ func TestJsonApiErrorMiddlewareMultipleErrorWrite(t *testing.T) { if _, err := io.WriteString(w, payload); err != nil { t.Fatal(err) } - }).Methods("GET") + }).Methods(http.MethodGet) r.Use(Error) rec := httptest.NewRecorder() - req := httptest.NewRequest("GET", "/foo", nil) + req := httptest.NewRequest(http.MethodGet, "/foo", nil) req.Header.Set("Accept", runtime.JSONAPIContentType) r.ServeHTTP(rec, req) + resp := rec.Result() b, err := io.ReadAll(resp.Body) - resp.Body.Close() + + defer func() { + err := resp.Body.Close() + assert.NoError(t, err) + }() + if err != nil { t.Fatal(err) } + var e struct { List runtime.Errors `json:"errors"` } + if err := json.Unmarshal(b, &e); err != nil { t.Fatal(err) } + if len(e.List) != 1 { t.Fatalf("expected only one record, got %v", len(e.List)) } + if payload != e.List[0].Title { t.Fatalf("error titles differ: expected %v, got %v", payload, e.List[0].Title) } @@ -111,7 +130,7 @@ func TestJsonApiErrorMiddlewareMultipleErrorWrite(t *testing.T) { func TestJsonApiErrorMiddlewareInvalidWriteOrder(t *testing.T) { r := mux.NewRouter() r.HandleFunc("/foo", func(w http.ResponseWriter, r *http.Request) { - w.WriteHeader(200) + w.WriteHeader(http.StatusOK) if _, err := io.WriteString(w, payload); err != nil { t.Fatal(err) } @@ -119,22 +138,29 @@ func TestJsonApiErrorMiddlewareInvalidWriteOrder(t *testing.T) { if ok && !jsonWriter.hasBytes { t.Fatal("expected hasBytes flag to be set") } - w.WriteHeader(400) + w.WriteHeader(http.StatusBadRequest) w.Header().Set("Content-Type", "text/plain") _, _ = io.WriteString(w, payload) // will get discarded - }).Methods("GET") + }).Methods(http.MethodGet) r.Use(Error) rec := httptest.NewRecorder() - req := httptest.NewRequest("GET", "/foo", nil) + req := httptest.NewRequest(http.MethodGet, "/foo", nil) req.Header.Set("Accept", runtime.JSONAPIContentType) r.ServeHTTP(rec, req) + resp := rec.Result() b, err := io.ReadAll(resp.Body) - resp.Body.Close() + + defer func() { + err := resp.Body.Close() + assert.NoError(t, err) + }() + if err != nil { t.Fatal(err) } + if payload != string(b) { t.Fatalf("bad response body, expected %q, got %q", payload, string(b)) } diff --git a/http/jsonapi/models_test.go b/http/jsonapi/models_test.go index 3ebd8e9e9..0f4dce073 100644 --- a/http/jsonapi/models_test.go +++ b/http/jsonapi/models_test.go @@ -113,6 +113,7 @@ func (b *Blog) JSONAPIRelationshipLinks(relation string) *Links { }, } } + if relation == "current_post" { return &Links{ "self": fmt.Sprintf("https://example.com/api/posts/%s", "3"), @@ -121,6 +122,7 @@ func (b *Blog) JSONAPIRelationshipLinks(relation string) *Links { }, } } + return nil } @@ -146,11 +148,13 @@ func (b *Blog) JSONAPIRelationshipMeta(relation string) *Meta { }, } } + if relation == "current_post" { return &Meta{ "detail": "extra current_post detail", } } + return nil } diff --git a/http/jsonapi/node.go b/http/jsonapi/node.go index 46b6f3cfd..ef07d7240 100644 --- a/http/jsonapi/node.go +++ b/http/jsonapi/node.go @@ -8,13 +8,13 @@ import ( "fmt" ) -// Payloader is used to encapsulate the One and Many payload types +// Payloader is used to encapsulate the One and Many payload types. type Payloader interface { clearIncluded() } // OnePayload is used to represent a generic JSON API payload where a single -// resource (Node) was included as an {} in the "data" key +// resource (Node) was included as an {} in the "data" key. type OnePayload struct { Data *Node `json:"data"` Included []*Node `json:"included,omitempty"` @@ -27,7 +27,7 @@ func (p *OnePayload) clearIncluded() { } // ManyPayload is used to represent a generic JSON API payload where many -// resources (Nodes) were included in an [] in the "data" key +// resources (Nodes) were included in an [] in the "data" key. type ManyPayload struct { Data []*Node `json:"data"` Included []*Node `json:"included,omitempty"` @@ -39,7 +39,7 @@ func (p *ManyPayload) clearIncluded() { p.Included = []*Node{} } -// Node is used to represent a generic JSON API Resource +// Node is used to represent a generic JSON API Resource. type Node struct { Type string `json:"type"` ID string `json:"id,omitempty"` @@ -50,7 +50,7 @@ type Node struct { Meta *Meta `json:"meta,omitempty"` } -// RelationshipOneNode is used to represent a generic has one JSON API relation +// RelationshipOneNode is used to represent a generic has one JSON API relation. type RelationshipOneNode struct { Data *Node `json:"data"` Links *Links `json:"links,omitempty"` @@ -58,7 +58,7 @@ type RelationshipOneNode struct { } // RelationshipManyNode is used to represent a generic has many JSON API -// relation +// relation. type RelationshipManyNode struct { Data []*Node `json:"data"` Links *Links `json:"links,omitempty"` @@ -83,11 +83,12 @@ func (l *Links) validate() (err error) { if !(isString || isLink) { return fmt.Errorf( - "The %s member of the links object was not a string or link object", + "the %s member of the links object was not a string or link object", k, ) } } + return } @@ -115,12 +116,12 @@ type RelationshipLinkable interface { type Meta map[string]interface{} // Metable is used to include document meta in response data -// e.g. {"foo": "bar"} +// e.g. {"foo": "bar"}. type Metable interface { JSONAPIMeta() *Meta } -// RelationshipMetable is used to include relationship meta in response data +// RelationshipMetable is used to include relationship meta in response data. type RelationshipMetable interface { // JSONRelationshipMeta will be invoked for each relationship with the corresponding relation name (e.g. `comments`) JSONAPIRelationshipMeta(relation string) *Meta diff --git a/http/jsonapi/request.go b/http/jsonapi/request.go index ab331cb43..0c1ecce9c 100644 --- a/http/jsonapi/request.go +++ b/http/jsonapi/request.go @@ -18,47 +18,49 @@ import ( ) const ( - unsupportedStructTagMsg = "Unsupported jsonapi tag annotation, %s" + unsupportedStructTagMsg = "unsupported jsonapi tag annotation, %s" ) var ( // ErrInvalidTime is returned when a struct has a time.Time type field, but // the JSON value was not a unix timestamp integer. - ErrInvalidTime = errors.New("Only numbers can be parsed as dates, unix timestamps") + ErrInvalidTime = errors.New("only numbers can be parsed as dates, unix timestamps") // ErrInvalidISO8601 is returned when a struct has a time.Time type field and includes // "iso8601" in the tag spec, but the JSON value was not an ISO8601 timestamp string. - ErrInvalidISO8601 = errors.New("Only strings can be parsed as dates, ISO8601 timestamps") + ErrInvalidISO8601 = errors.New("only strings can be parsed as dates, ISO8601 timestamps") // ErrUnknownFieldNumberType is returned when the JSON value was a float // (numeric) but the Struct field was a non numeric type (i.e. not int, uint, - // float, etc) - ErrUnknownFieldNumberType = errors.New("The struct field was not of a known number type") + // float, etc). + ErrUnknownFieldNumberType = errors.New("the struct field was not of a known number type") // ErrInvalidType is returned when the given type is incompatible with the expected type. - ErrInvalidType = errors.New("Invalid type provided") // I wish we used punctuation. + ErrInvalidType = errors.New("invalid type provided") // I wish we used punctuation. ) -// ErrUnsupportedPtrType is returned when the Struct field was a pointer but -// the JSON value was of a different type -type ErrUnsupportedPtrType struct { +// UnsupportedPtrTypeError is returned when the Struct field was a pointer but +// the JSON value was of a different type. +type UnsupportedPtrTypeError struct { rf reflect.Value t reflect.Type structField reflect.StructField } -func (eupt ErrUnsupportedPtrType) Error() string { +func (eupt UnsupportedPtrTypeError) Error() string { typeName := eupt.t.Elem().Name() kind := eupt.t.Elem().Kind() + if kind.String() != "" && kind.String() != typeName { typeName = fmt.Sprintf("%s (%s)", typeName, kind.String()) } + return fmt.Sprintf( "jsonapi: Can't unmarshal %+v (%s) to struct field `%s`, which is a pointer to `%s`", eupt.rf, eupt.rf.Type().Kind(), eupt.structField.Name, typeName, ) } -func newErrUnsupportedPtrType(rf reflect.Value, t reflect.Type, structField reflect.StructField) error { - return ErrUnsupportedPtrType{rf, t, structField} +func newErrUnsupportedPtrType(rf reflect.Value, t reflect.Type, structField reflect.StructField) UnsupportedPtrTypeError { + return UnsupportedPtrTypeError{rf, t, structField} } // UnmarshalPayload converts an io into a struct instance using jsonapi tags on @@ -83,7 +85,7 @@ func newErrUnsupportedPtrType(rf reflect.Value, t reflect.Type, structField refl // // ...do stuff with your blog... // // w.Header().Set("Content-Type", jsonapi.MediaType) -// w.WriteHeader(201) +// w.WriteHeader(http.StatusCreated) // // if err := jsonapi.MarshalPayload(w, blog); err != nil { // http.Error(w, err.Error(), 500) @@ -102,6 +104,7 @@ func UnmarshalPayload(in io.Reader, model interface{}) error { if payload.Included != nil { includedMap := make(map[string]*Node) + for _, included := range payload.Included { key := fmt.Sprintf("%s,%s", included.Type, included.ID) includedMap[key] = included @@ -109,6 +112,7 @@ func UnmarshalPayload(in io.Reader, model interface{}) error { return unmarshalNode(payload.Data, reflect.ValueOf(model), &includedMap) } + return unmarshalNode(payload.Data, reflect.ValueOf(model), nil) } @@ -133,10 +137,11 @@ func UnmarshalManyPayload(in io.Reader, t reflect.Type) ([]interface{}, error) { for _, data := range payload.Data { model := reflect.New(t.Elem()) - err := unmarshalNode(data, model, &includedMap) - if err != nil { + + if err := unmarshalNode(data, model, &includedMap); err != nil { return nil, err } + models = append(models, model.Interface()) } @@ -157,6 +162,7 @@ func unmarshalNode(data *Node, model reflect.Value, included *map[string]*Node) for i := 0; i < modelValue.NumField(); i++ { fieldType := modelType.Field(i) + tag := fieldType.Tag.Get("jsonapi") if tag == "" { continue @@ -186,10 +192,11 @@ func unmarshalNode(data *Node, model reflect.Value, included *map[string]*Node) // Check the JSON API Type if data.Type != args[1] { er = fmt.Errorf( - "Trying to Unmarshal an object of type %#v, but %#v does not match", + "trying to Unmarshal an object of type %#v, but %#v does not match", data.Type, args[1], ) + break } @@ -251,6 +258,7 @@ func unmarshalNode(data *Node, model reflect.Value, included *map[string]*Node) } structField := fieldType + value, err := unmarshalAttribute(attribute, args, structField, fieldValue) if err != nil { er = err @@ -273,8 +281,13 @@ func unmarshalNode(data *Node, model reflect.Value, included *map[string]*Node) buf := bytes.NewBuffer(nil) - json.NewEncoder(buf).Encode(data.Relationships[args[1]]) // nolint: errcheck - json.NewDecoder(buf).Decode(relationship) // nolint: errcheck + if err := json.NewEncoder(buf).Encode(data.Relationships[args[1]]); err != nil { + return err + } + + if err := json.NewDecoder(buf).Decode(relationship); err != nil { + return err + } data := relationship.Data models := reflect.New(fieldValue.Type()).Elem() @@ -301,10 +314,13 @@ func unmarshalNode(data *Node, model reflect.Value, included *map[string]*Node) buf := bytes.NewBuffer(nil) - json.NewEncoder(buf).Encode( // nolint: errcheck - data.Relationships[args[1]], - ) - json.NewDecoder(buf).Decode(relationship) // nolint: errcheck + if err := json.NewEncoder(buf).Encode(data.Relationships[args[1]]); err != nil { + return err + } + + if err := json.NewDecoder(buf).Decode(relationship); err != nil { + return err + } /* http://jsonapi.org/format/#document-resource-object-relationships @@ -327,9 +343,7 @@ func unmarshalNode(data *Node, model reflect.Value, included *map[string]*Node) } fieldValue.Set(m) - } - } else { er = fmt.Errorf(unsupportedStructTagMsg, annotation) } @@ -357,8 +371,8 @@ func assign(field, value reflect.Value) { // initialize pointer so it's value // can be set by assignValue field.Set(reflect.New(field.Type().Elem())) - field = field.Elem() + field = field.Elem() } assignValue(field, value) @@ -390,75 +404,67 @@ func unmarshalAttribute( args []string, structField reflect.StructField, fieldValue reflect.Value, -) (value reflect.Value, err error) { +) (reflect.Value, error) { var attribute interface{} - err = json.Unmarshal(rawAttribute, &attribute) - if err != nil { + + if err := json.Unmarshal(rawAttribute, &attribute); err != nil { return reflect.Value{}, err } - value = reflect.ValueOf(attribute) + value := reflect.ValueOf(attribute) fieldType := structField.Type // decimal.Decimal and *decimal.Decimal if fieldValue.Type() == reflect.TypeOf(decimal.Decimal{}) || fieldValue.Type() == reflect.TypeOf(new(decimal.Decimal)) { - value, err = handleDecimal(rawAttribute) - return + return handleDecimal(rawAttribute) } // map[string][]string if fieldValue.Type() == reflect.TypeOf(map[string][]string{}) { - value, err = handleMapStringSlice(rawAttribute) - return + return handleMapStringSlice(rawAttribute) } // Handle field of type time.Time if fieldValue.Type() == reflect.TypeOf(time.Time{}) || fieldValue.Type() == reflect.TypeOf(new(time.Time)) { - value, err = handleTime(attribute, args, fieldValue) - return + return handleTime(attribute, args, fieldValue) } // Handle field of type struct if fieldValue.Type().Kind() == reflect.Struct { - value, err = handleStruct(attribute, fieldValue) - return + return handleStruct(attribute, fieldValue) } // Handle field containing slices if fieldValue.Type().Kind() == reflect.Slice { value = reflect.New(fieldValue.Type()) - err = json.Unmarshal(rawAttribute, value.Interface()) - return + return value, json.Unmarshal(rawAttribute, value.Interface()) } // JSON value was a float (numeric) if value.Kind() == reflect.Float64 { - value, err = handleNumeric(attribute, fieldType, fieldValue) - return + return handleNumeric(attribute, fieldType, fieldValue) } // Field was a Pointer type if fieldValue.Kind() == reflect.Ptr { - value, err = handlePointer(attribute, args, fieldType, fieldValue, structField) - return + return handlePointer(attribute, args, fieldType, fieldValue, structField) } // As a final catch-all, ensure types line up to avoid a runtime panic. if fieldValue.Kind() != value.Kind() { - err = fmt.Errorf("got value %q expected type %v: %w", value, fieldType, ErrInvalidType) - return + return value, fmt.Errorf("got value %q expected type %v: %w", value, fieldType, ErrInvalidType) } - return + return value, nil } func handleDecimal(attribute json.RawMessage) (reflect.Value, error) { var dec decimal.Decimal - err := json.Unmarshal(attribute, &dec) - if err != nil { - return reflect.Value{}, fmt.Errorf("can't decode decimal from value %q: %v", string(attribute), err) + + if err := json.Unmarshal(attribute, &dec); err != nil { + return reflect.Value{}, fmt.Errorf("can't decode decimal from value %q: %w", string(attribute), err) } return reflect.ValueOf(dec), nil @@ -466,9 +472,9 @@ func handleDecimal(attribute json.RawMessage) (reflect.Value, error) { func handleMapStringSlice(attribute json.RawMessage) (reflect.Value, error) { var m map[string][]string - err := json.Unmarshal(attribute, &m) - if err != nil { - return reflect.Value{}, fmt.Errorf("can't decode map string slice from value %q: %v", string(attribute), err) + + if err := json.Unmarshal(attribute, &m); err != nil { + return reflect.Value{}, fmt.Errorf("can't decode map string slice from value %q: %w", string(attribute), err) } return reflect.ValueOf(m), nil @@ -476,6 +482,7 @@ func handleMapStringSlice(attribute json.RawMessage) (reflect.Value, error) { func handleTime(attribute interface{}, args []string, fieldValue reflect.Value) (reflect.Value, error) { var isIso8601 bool + v := reflect.ValueOf(attribute) if len(args) > 2 { @@ -489,7 +496,7 @@ func handleTime(attribute interface{}, args []string, fieldValue reflect.Value) if isIso8601 { var tm string if v.Kind() == reflect.String { - tm = v.Interface().(string) + tm, _ = v.Interface().(string) } else { return reflect.ValueOf(time.Now()), ErrInvalidISO8601 } @@ -509,7 +516,8 @@ func handleTime(attribute interface{}, args []string, fieldValue reflect.Value) var at int64 if v.Kind() == reflect.Float64 { - at = int64(v.Interface().(float64)) + atTmp, _ := v.Interface().(float64) + at = int64(atTmp) } else if v.Kind() == reflect.Int { at = v.Int() } else { @@ -527,7 +535,7 @@ func handleNumeric( fieldValue reflect.Value, ) (reflect.Value, error) { v := reflect.ValueOf(attribute) - floatValue := v.Interface().(float64) + floatValue, _ := v.Interface().(float64) var kind reflect.Kind if fieldValue.Kind() == reflect.Ptr { @@ -584,12 +592,13 @@ func handleNumeric( func handlePointer( attribute interface{}, - args []string, + _ []string, fieldType reflect.Type, fieldValue reflect.Value, structField reflect.StructField, ) (reflect.Value, error) { t := fieldValue.Type() + var concreteVal reflect.Value if attribute == nil { @@ -605,11 +614,13 @@ func handlePointer( concreteVal = reflect.ValueOf(&cVal) case map[string]interface{}: var err error + concreteVal, err = handleStruct(attribute, fieldValue) if err != nil { return reflect.Value{}, newErrUnsupportedPtrType( reflect.ValueOf(attribute), fieldType, structField) } + return concreteVal, err default: return reflect.Value{}, newErrUnsupportedPtrType( diff --git a/http/jsonapi/request_test.go b/http/jsonapi/request_test.go index 9ace13911..9e77d2139 100644 --- a/http/jsonapi/request_test.go +++ b/http/jsonapi/request_test.go @@ -17,6 +17,7 @@ import ( "time" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) func TestUnmarshall_attrStringSlice(t *testing.T) { @@ -34,6 +35,7 @@ func TestUnmarshall_attrStringSlice(t *testing.T) { }, }, } + b, err := json.Marshal(data) if err != nil { t.Fatal(err) @@ -53,9 +55,11 @@ func TestUnmarshall_attrStringSlice(t *testing.T) { if out.Decimal1.String() != "9.9999999999999999999" { t.Fatalf("Expected json dec1 data to be %#v got: %#v", "9.9999999999999999999", out.Decimal1.String()) } + if out.Decimal2.String() != "9.9999999999999999999" { t.Fatalf("Expected json dec2 data to be %#v got: %#v", "9.9999999999999999999", out.Decimal2.String()) } + if out.Decimal3.String() != "10" { t.Fatalf("Expected json dec2 data to be %#v got: %#v", 10, out.Decimal3.String()) } @@ -104,6 +108,7 @@ func TestUnmarshall_MapStringSlice(t *testing.T) { for _, tc := range tcs { t.Run(tc.name, func(t *testing.T) { out := &Book{} + b, err := json.Marshal(tc.input) if err != nil { t.Fatal(err) @@ -123,18 +128,26 @@ func TestUnmarshalToStructWithPointerAttr(t *testing.T) { "int-val": json.RawMessage(`8`), "float-val": json.RawMessage(`1.1`), } - if err := UnmarshalPayload(sampleWithPointerPayload(in), out); err != nil { + + payload, err := sampleWithPointerPayload(in) + require.NoError(t, err) + + if err := UnmarshalPayload(payload, out); err != nil { t.Fatal(err) } + if *out.Name != "The name" { t.Fatalf("Error unmarshalling to string ptr") } + if !*out.IsActive { t.Fatalf("Error unmarshalling to bool ptr") } + if *out.IntVal != 8 { t.Fatalf("Error unmarshalling to int ptr") } + if *out.FloatVal != 1.1 { t.Fatalf("Error unmarshalling to float ptr") } @@ -156,7 +169,10 @@ func TestUnmarshalPayloadWithPointerID(t *testing.T) { out := new(WithPointer) attrs := map[string]json.RawMessage{} - if err := UnmarshalPayload(sampleWithPointerPayload(attrs), out); err != nil { + payload, err := sampleWithPointerPayload(attrs) + require.NoError(t, err) + + if err := UnmarshalPayload(payload, out); err != nil { t.Fatalf("Error unmarshalling to Foo") } @@ -164,6 +180,7 @@ func TestUnmarshalPayloadWithPointerID(t *testing.T) { if out.ID == nil { t.Fatalf("Error unmarshalling; expected ID ptr to be not nil") } + if e, a := uint64(2), *out.ID; e != a { t.Fatalf("Was expecting the ID to have a value of %d, got %d", e, a) } @@ -176,7 +193,10 @@ func TestUnmarshalPayloadWithPointerAttr_AbsentVal(t *testing.T) { "is-active": json.RawMessage(`true`), } - if err := UnmarshalPayload(sampleWithPointerPayload(in), out); err != nil { + payload, err := sampleWithPointerPayload(in) + require.NoError(t, err) + + if err := UnmarshalPayload(payload, out); err != nil { t.Fatalf("Error unmarshalling to Foo") } @@ -198,15 +218,19 @@ func TestUnmarshalToStructWithPointerAttr_BadType_bool(t *testing.T) { } expectedErrorMessage := "jsonapi: Can't unmarshal true (bool) to struct field `Name`, which is a pointer to `string`" - err := UnmarshalPayload(sampleWithPointerPayload(in), out) + payload, err := sampleWithPointerPayload(in) + require.NoError(t, err) + err = UnmarshalPayload(payload, out) if err == nil { t.Fatalf("Expected error due to invalid type.") } + if err.Error() != expectedErrorMessage { t.Fatalf("Unexpected error message: %s", err.Error()) } - if _, ok := err.(ErrUnsupportedPtrType); !ok { + + if _, ok := err.(UnsupportedPtrTypeError); !ok { //nolint:errorlint t.Fatalf("Unexpected error type: %s", reflect.TypeOf(err)) } } @@ -218,15 +242,19 @@ func TestUnmarshalToStructWithPointerAttr_BadType_MapPtr(t *testing.T) { } expectedErrorMessage := "jsonapi: Can't unmarshal map[a:5] (map) to struct field `Name`, which is a pointer to `string`" - err := UnmarshalPayload(sampleWithPointerPayload(in), out) + payload, err := sampleWithPointerPayload(in) + require.NoError(t, err) + err = UnmarshalPayload(payload, out) if err == nil { t.Fatalf("Expected error due to invalid type.") } + if err.Error() != expectedErrorMessage { t.Fatalf("Unexpected error message: %s", err.Error()) } - if _, ok := err.(ErrUnsupportedPtrType); !ok { + + if _, ok := err.(UnsupportedPtrTypeError); !ok { //nolint:errorlint t.Fatalf("Unexpected error type: %s", reflect.TypeOf(err)) } } @@ -238,15 +266,19 @@ func TestUnmarshalToStructWithPointerAttr_BadType_Struct(t *testing.T) { } expectedErrorMessage := "jsonapi: Can't unmarshal map[A:5] (map) to struct field `Name`, which is a pointer to `string`" - err := UnmarshalPayload(sampleWithPointerPayload(in), out) + payload, err := sampleWithPointerPayload(in) + require.NoError(t, err) + err = UnmarshalPayload(payload, out) if err == nil { t.Fatalf("Expected error due to invalid type.") } + if err.Error() != expectedErrorMessage { t.Fatalf("Unexpected error message: %s", err.Error()) } - if _, ok := err.(ErrUnsupportedPtrType); !ok { + + if _, ok := err.(UnsupportedPtrTypeError); !ok { //nolint:errorlint t.Fatalf("Unexpected error type: %s", reflect.TypeOf(err)) } } @@ -258,15 +290,19 @@ func TestUnmarshalToStructWithPointerAttr_BadType_IntSlice(t *testing.T) { } expectedErrorMessage := "jsonapi: Can't unmarshal [4 5] (slice) to struct field `Name`, which is a pointer to `string`" - err := UnmarshalPayload(sampleWithPointerPayload(in), out) + payload, err := sampleWithPointerPayload(in) + require.NoError(t, err) + err = UnmarshalPayload(payload, out) if err == nil { t.Fatalf("Expected error due to invalid type.") } + if err.Error() != expectedErrorMessage { t.Fatalf("Unexpected error message: %s", err.Error()) } - if _, ok := err.(ErrUnsupportedPtrType); !ok { + + if _, ok := err.(UnsupportedPtrTypeError); !ok { //nolint:errorlint t.Fatalf("Unexpected error type: %s", reflect.TypeOf(err)) } } @@ -285,6 +321,7 @@ func TestStringPointerField(t *testing.T) { }, }, } + payload, err := json.Marshal(data) if err != nil { t.Fatal(err) @@ -299,6 +336,7 @@ func TestStringPointerField(t *testing.T) { if book.Description == nil { t.Fatal("Was not expecting a nil pointer for book.Description") } + if expected, actual := description, *book.Description; expected != actual { t.Fatalf("Was expecting descript to be `%s`, got `%s`", expected, actual) } @@ -306,10 +344,12 @@ func TestStringPointerField(t *testing.T) { func TestMalformedTag(t *testing.T) { out := new(BadModel) - err := UnmarshalPayload(samplePayload(), out) - if err == nil || err != ErrBadJSONAPIStructTag { - t.Fatalf("Did not error out with wrong number of arguments in tag") - } + + payload, err := samplePayload() + require.NoError(t, err) + + err = UnmarshalPayload(payload, out) + require.ErrorIs(t, err, ErrBadJSONAPIStructTag) } func TestUnmarshalInvalidJSON(t *testing.T) { @@ -317,7 +357,6 @@ func TestUnmarshalInvalidJSON(t *testing.T) { out := new(Blog) err := UnmarshalPayload(in, out) - if err == nil { t.Fatalf("Did not error out the invalid JSON.") } @@ -330,7 +369,7 @@ func TestUnmarshalInvalidJSON_BadType(t *testing.T) { Error error }{ // The `Field` values here correspond to the `ModelBadTypes` jsonapi fields. {Field: "string_field", BadValue: json.RawMessage(`0`), Error: ErrUnknownFieldNumberType}, // Expected string. - {Field: "float_field", BadValue: json.RawMessage(`"A string."`), Error: errors.New("got value \"A string.\" expected type float64: Invalid type provided")}, // Expected float64. + {Field: "float_field", BadValue: json.RawMessage(`"A string."`), Error: errors.New("got value \"A string.\" expected type float64: invalid type provided")}, // Expected float64. {Field: "time_field", BadValue: json.RawMessage(`"A string."`), Error: ErrInvalidTime}, // Expected int64. {Field: "time_ptr_field", BadValue: json.RawMessage(`"A string."`), Error: ErrInvalidTime}, // Expected *time / int64. } @@ -341,20 +380,23 @@ func TestUnmarshalInvalidJSON_BadType(t *testing.T) { in[test.Field] = test.BadValue expectedErrorMessage := test.Error.Error() - err := UnmarshalPayload(samplePayloadWithBadTypes(in), out) + payload, err := samplePayloadWithBadTypes(in) + require.NoError(t, err) + err = UnmarshalPayload(payload, out) if err == nil { t.Fatalf("(Test %d) Expected error due to invalid type.", i+1) } - if err.Error() != expectedErrorMessage { - t.Fatalf("(Test %d) Unexpected error message: %q \nexpected: %q", i+1, expectedErrorMessage, err.Error()) - } + + require.Equal(t, expectedErrorMessage, err.Error()) }) } } func TestUnmarshalSetsID(t *testing.T) { - in := samplePayloadWithID() + in, err := samplePayloadWithID() + require.NoError(t, err) + out := new(Blog) if err := UnmarshalPayload(in, out); err != nil { @@ -368,21 +410,24 @@ func TestUnmarshalSetsID(t *testing.T) { func TestUnmarshal_nonNumericID(t *testing.T) { data := samplePayloadWithoutIncluded() - data["data"].(map[string]interface{})["id"] = "non-numeric-id" + + dataMap, ok := data["data"].(map[string]interface{}) + if !ok { + t.Fatal("data is not a map") + } + + dataMap["id"] = "non-numeric-id" + payload, err := json.Marshal(data) if err != nil { t.Fatal(err) } + in := bytes.NewReader(payload) out := new(Post) - if err := UnmarshalPayload(in, out); err != ErrBadJSONAPIID { - t.Fatalf( - "Was expecting a `%s` error, got `%s`", - ErrBadJSONAPIID, - err, - ) - } + err = UnmarshalPayload(in, out) + require.ErrorIs(t, err, ErrBadJSONAPIID) } func TestUnmarshalSetsAttrs(t *testing.T) { @@ -411,8 +456,8 @@ func TestUnmarshalParsesISO8601(t *testing.T) { } in := bytes.NewBuffer(nil) - err := json.NewEncoder(in).Encode(payload) - if err != nil { + + if err := json.NewEncoder(in).Encode(payload); err != nil { log.Fatal(err) } @@ -440,8 +485,8 @@ func TestUnmarshalParsesISO8601TimePointer(t *testing.T) { } in := bytes.NewBuffer(nil) - err := json.NewEncoder(in).Encode(payload) - if err != nil { + + if err := json.NewEncoder(in).Encode(payload); err != nil { t.Fatal(err) } @@ -469,16 +514,15 @@ func TestUnmarshalInvalidISO8601(t *testing.T) { } in := bytes.NewBuffer(nil) - err := json.NewEncoder(in).Encode(payload) - if err != nil { + + if err := json.NewEncoder(in).Encode(payload); err != nil { t.Fatal(err) } out := new(Timestamp) - if err := UnmarshalPayload(in, out); err != ErrInvalidISO8601 { - t.Fatalf("Expected ErrInvalidISO8601, got %v", err) - } + err := UnmarshalPayload(in, out) + require.ErrorIs(t, err, ErrInvalidISO8601) } func TestUnmarshalRelationshipsWithoutIncluded(t *testing.T) { @@ -486,6 +530,7 @@ func TestUnmarshalRelationshipsWithoutIncluded(t *testing.T) { if err != nil { t.Fatal(err) } + in := bytes.NewReader(data) out := new(Post) @@ -536,6 +581,7 @@ func TestUnmarshalNullRelationship(t *testing.T) { }, }, } + data, err := json.Marshal(sample) if err != nil { t.Fatal(err) @@ -569,6 +615,7 @@ func TestUnmarshalNullRelationshipInSlice(t *testing.T) { }, }, } + data, err := json.Marshal(sample) if err != nil { t.Fatal(err) @@ -703,7 +750,10 @@ func TestUnmarshalNestedRelationshipsSideloaded(t *testing.T) { func TestUnmarshalNestedRelationshipsEmbedded_withClientIDs(t *testing.T) { model := new(Blog) - if err := UnmarshalPayload(samplePayload(), model); err != nil { + payload, err := samplePayload() + require.NoError(t, err) + + if err := UnmarshalPayload(payload, model); err != nil { t.Fatal(err) } @@ -713,7 +763,11 @@ func TestUnmarshalNestedRelationshipsEmbedded_withClientIDs(t *testing.T) { } func unmarshalSamplePayload() (*Blog, error) { - in := samplePayload() + in, err := samplePayload() + if err != nil { + return nil, err + } + out := new(Blog) if err := UnmarshalPayload(in, out); err != nil { @@ -749,6 +803,7 @@ func TestUnmarshalManyPayload(t *testing.T) { if err != nil { t.Fatal(err) } + in := bytes.NewReader(data) posts, err := UnmarshalManyPayload(in, reflect.TypeOf(new(Post))) @@ -805,6 +860,7 @@ func TestManyPayload_withLinks(t *testing.T) { if err != nil { t.Fatal(err) } + in := bytes.NewReader(data) payload := new(ManyPayload) @@ -822,6 +878,7 @@ func TestManyPayload_withLinks(t *testing.T) { if !ok { t.Fatal("Was expecting a non nil ptr Link field") } + if e, a := firstPageURL, first; e != a { t.Fatalf("Was expecting links.%s to have a value of %s, got %s", KeyFirstPage, e, a) } @@ -830,6 +887,7 @@ func TestManyPayload_withLinks(t *testing.T) { if !ok { t.Fatal("Was expecting a non nil ptr Link field") } + if e, a := prevPageURL, prev; e != a { t.Fatalf("Was expecting links.%s to have a value of %s, got %s", KeyPreviousPage, e, a) } @@ -838,6 +896,7 @@ func TestManyPayload_withLinks(t *testing.T) { if !ok { t.Fatal("Was expecting a non nil ptr Link field") } + if e, a := nextPageURL, next; e != a { t.Fatalf("Was expecting links.%s to have a value of %s, got %s", KeyNextPage, e, a) } @@ -846,6 +905,7 @@ func TestManyPayload_withLinks(t *testing.T) { if !ok { t.Fatal("Was expecting a non nil ptr Link field") } + if e, a := lastPageURL, last; e != a { t.Fatalf("Was expecting links.%s to have a value of %s, got %s", KeyLastPage, e, a) } @@ -869,6 +929,7 @@ func TestUnmarshalCustomTypeAttributes(t *testing.T) { }, }, } + payload, err := json.Marshal(data) if err != nil { t.Fatal(err) @@ -883,9 +944,11 @@ func TestUnmarshalCustomTypeAttributes(t *testing.T) { if expected, actual := customInt, customAttributeTypes.Int; expected != actual { t.Fatalf("Was expecting custom int to be `%d`, got `%d`", expected, actual) } + if expected, actual := customInt, *customAttributeTypes.IntPtr; expected != actual { t.Fatalf("Was expecting custom int pointer to be `%d`, got `%d`", expected, actual) } + if customAttributeTypes.IntPtrNull != nil { t.Fatalf("Was expecting custom int pointer to be , got `%d`", customAttributeTypes.IntPtrNull) } @@ -893,6 +956,7 @@ func TestUnmarshalCustomTypeAttributes(t *testing.T) { if expected, actual := customFloat, customAttributeTypes.Float; expected != actual { t.Fatalf("Was expecting custom float to be `%f`, got `%f`", expected, actual) } + if expected, actual := customString, customAttributeTypes.String; expected != actual { t.Fatalf("Was expecting custom string to be `%s`, got `%s`", expected, actual) } @@ -912,6 +976,7 @@ func TestUnmarshalCustomTypeAttributes_ErrInvalidType(t *testing.T) { }, }, } + payload, err := json.Marshal(data) if err != nil { t.Fatal(err) @@ -919,12 +984,13 @@ func TestUnmarshalCustomTypeAttributes_ErrInvalidType(t *testing.T) { // Parse JSON API payload customAttributeTypes := new(CustomAttributeTypes) + err = UnmarshalPayload(bytes.NewReader(payload), customAttributeTypes) if err == nil { t.Fatal("Expected an error unmarshalling the payload due to type mismatch, got none") } - e := errors.New("got value \"bad\" expected type jsonapi.CustomIntType: Invalid type provided") + e := errors.New("got value \"bad\" expected type jsonapi.CustomIntType: invalid type provided") if err.Error() != e.Error() { t.Fatalf("Expected error to be %q,\nwas %q", e, err) } @@ -963,7 +1029,7 @@ func samplePayloadWithoutIncluded() map[string]interface{} { } } -func samplePayload() io.Reader { +func samplePayload() (io.Reader, error) { payload := &OnePayload{ Data: &Node{ Type: "blogs", @@ -1028,12 +1094,15 @@ func samplePayload() io.Reader { } out := bytes.NewBuffer(nil) - json.NewEncoder(out).Encode(payload) // nolint: errcheck - return out + if err := json.NewEncoder(out).Encode(payload); err != nil { + return nil, err + } + + return out, nil } -func samplePayloadWithID() io.Reader { +func samplePayloadWithID() (io.Reader, error) { payload := &OnePayload{ Data: &Node{ ID: "2", @@ -1046,12 +1115,15 @@ func samplePayloadWithID() io.Reader { } out := bytes.NewBuffer(nil) - json.NewEncoder(out).Encode(payload) // nolint: errcheck - return out + if err := json.NewEncoder(out).Encode(payload); err != nil { + return nil, err + } + + return out, nil } -func samplePayloadWithBadTypes(m map[string]json.RawMessage) io.Reader { +func samplePayloadWithBadTypes(m map[string]json.RawMessage) (io.Reader, error) { payload := &OnePayload{ Data: &Node{ ID: "2", @@ -1061,12 +1133,15 @@ func samplePayloadWithBadTypes(m map[string]json.RawMessage) io.Reader { } out := bytes.NewBuffer(nil) - json.NewEncoder(out).Encode(payload) // nolint: errcheck - return out + if err := json.NewEncoder(out).Encode(payload); err != nil { + return nil, err + } + + return out, nil } -func sampleWithPointerPayload(m map[string]json.RawMessage) io.Reader { +func sampleWithPointerPayload(m map[string]json.RawMessage) (io.Reader, error) { payload := &OnePayload{ Data: &Node{ ID: "2", @@ -1076,9 +1151,12 @@ func sampleWithPointerPayload(m map[string]json.RawMessage) io.Reader { } out := bytes.NewBuffer(nil) - json.NewEncoder(out).Encode(payload) // nolint: errcheck - return out + if err := json.NewEncoder(out).Encode(payload); err != nil { + return nil, err + } + + return out, nil } func testModel() *Blog { @@ -1153,8 +1231,8 @@ func samplePayloadWithSideloaded() io.Reader { testModel := testModel() out := bytes.NewBuffer(nil) - err := MarshalPayload(out, testModel) - if err != nil { + + if err := MarshalPayload(out, testModel); err != nil { panic(err) } @@ -1163,14 +1241,14 @@ func samplePayloadWithSideloaded() io.Reader { func sampleSerializedEmbeddedTestModel() *Blog { out := bytes.NewBuffer(nil) - err := MarshalOnePayloadEmbedded(out, testModel()) - if err != nil { + + if err := MarshalOnePayloadEmbedded(out, testModel()); err != nil { panic(err) } blog := new(Blog) - err = UnmarshalPayload(out, blog) - if err != nil { + + if err := UnmarshalPayload(out, blog); err != nil { panic(err) } @@ -1182,11 +1260,13 @@ func TestUnmarshalNestedStructPtr(t *testing.T) { Firstname string `jsonapi:"attr,firstname"` Surname string `jsonapi:"attr,surname"` } + type Movie struct { ID string `jsonapi:"primary,movies"` Name string `jsonapi:"attr,name"` Director *Director `jsonapi:"attr,director"` } + sample := map[string]interface{}{ "data": map[string]interface{}{ "type": "movies", @@ -1205,6 +1285,7 @@ func TestUnmarshalNestedStructPtr(t *testing.T) { if err != nil { t.Fatal(err) } + in := bytes.NewReader(data) out := new(Movie) @@ -1215,9 +1296,11 @@ func TestUnmarshalNestedStructPtr(t *testing.T) { if out.Name != "The Shawshank Redemption" { t.Fatalf("expected out.Name to be `The Shawshank Redemption`, but got `%s`", out.Name) } + if out.Director.Firstname != "Frank" { t.Fatalf("expected out.Director.Firstname to be `Frank`, but got `%s`", out.Director.Firstname) } + if out.Director.Surname != "Darabont" { t.Fatalf("expected out.Director.Surname to be `Darabont`, but got `%s`", out.Director.Surname) } @@ -1265,6 +1348,7 @@ func TestUnmarshalNestedStruct(t *testing.T) { if err != nil { t.Fatal(err) } + in := bytes.NewReader(data) out := new(Company) @@ -1369,6 +1453,7 @@ func TestUnmarshalNestedStructSlice(t *testing.T) { if err != nil { t.Fatal(err) } + in := bytes.NewReader(data) out := new(Company) diff --git a/http/jsonapi/response.go b/http/jsonapi/response.go index b4c967f23..e97518955 100644 --- a/http/jsonapi/response.go +++ b/http/jsonapi/response.go @@ -19,7 +19,7 @@ import ( var ( // ErrBadJSONAPIStructTag is returned when the Struct field's JSON API // annotation is invalid. - ErrBadJSONAPIStructTag = errors.New("Bad jsonapi struct tag format") + ErrBadJSONAPIStructTag = errors.New("bad jsonapi struct tag format") // ErrBadJSONAPIID is returned when the Struct JSON API annotated "id" field // was not a valid numeric type. ErrBadJSONAPIID = errors.New( @@ -96,6 +96,7 @@ func Marshal(models interface{}) (Payloader, error) { if er := jl.validate(); er != nil { return nil, er } + payload.Links = linkableModels.JSONAPILinks() } @@ -109,6 +110,7 @@ func Marshal(models interface{}) (Payloader, error) { if reflect.Indirect(vals).Kind() != reflect.Struct { return nil, ErrUnexpectedType } + return marshalOne(models) default: return nil, ErrUnexpectedType @@ -127,6 +129,7 @@ func MarshalPayloadWithoutIncluded(w io.Writer, model interface{}) error { if err != nil { return err } + payload.clearIncluded() return json.NewEncoder(w).Encode(payload) @@ -142,6 +145,7 @@ func marshalOne(model interface{}) (*OnePayload, error) { if err != nil { return nil, err } + payload := &OnePayload{Data: rootNode} payload.Included = nodeMapValues(&included) @@ -163,8 +167,10 @@ func marshalMany(models []interface{}) (*ManyPayload, error) { if err != nil { return nil, err } + payload.Data = append(payload.Data, node) } + payload.Included = nodeMapValues(&included) return payload, nil @@ -202,6 +208,7 @@ func visitModelNode(model interface{}, included *map[string]*Node, node := new(Node) var er error + value := reflect.ValueOf(model) if value.IsNil() { return nil, nil @@ -212,6 +219,7 @@ func visitModelNode(model interface{}, included *map[string]*Node, for i := 0; i < modelValue.NumField(); i++ { structField := modelValue.Type().Field(i) + tag := structField.Tag.Get(annotationJSONAPI) if tag == "" { continue @@ -250,27 +258,77 @@ func visitModelNode(model interface{}, included *map[string]*Node, // Handle allowed types switch kind { case reflect.String: - node.ID = v.Interface().(string) + node.ID, _ = v.Interface().(string) case reflect.Int: - node.ID = strconv.FormatInt(int64(v.Interface().(int)), 10) + val, ok := v.Interface().(int) + if !ok { + return nil, errors.New("could not assert int") + } + + node.ID = strconv.FormatInt(int64(val), 10) case reflect.Int8: - node.ID = strconv.FormatInt(int64(v.Interface().(int8)), 10) + val, ok := v.Interface().(int8) + if !ok { + return nil, errors.New("could not assert int8") + } + + node.ID = strconv.FormatInt(int64(val), 10) case reflect.Int16: - node.ID = strconv.FormatInt(int64(v.Interface().(int16)), 10) + val, ok := v.Interface().(int16) + if !ok { + return nil, errors.New("could not assert int16") + } + + node.ID = strconv.FormatInt(int64(val), 10) case reflect.Int32: - node.ID = strconv.FormatInt(int64(v.Interface().(int32)), 10) + val, ok := v.Interface().(int32) + if !ok { + return nil, errors.New("could not assert int32") + } + + node.ID = strconv.FormatInt(int64(val), 10) case reflect.Int64: - node.ID = strconv.FormatInt(v.Interface().(int64), 10) + val, ok := v.Interface().(int64) + if !ok { + return nil, errors.New("could not assert int64") + } + + node.ID = strconv.FormatInt(val, 10) case reflect.Uint: - node.ID = strconv.FormatUint(uint64(v.Interface().(uint)), 10) + val, ok := v.Interface().(uint) + if !ok { + return nil, errors.New("could not assert uint") + } + + node.ID = strconv.FormatUint(uint64(val), 10) case reflect.Uint8: - node.ID = strconv.FormatUint(uint64(v.Interface().(uint8)), 10) + val, ok := v.Interface().(uint8) + if !ok { + return nil, errors.New("could not assert uint8") + } + + node.ID = strconv.FormatUint(uint64(val), 10) case reflect.Uint16: - node.ID = strconv.FormatUint(uint64(v.Interface().(uint16)), 10) + val, ok := v.Interface().(uint16) + if !ok { + return nil, errors.New("could not assert uint16") + } + + node.ID = strconv.FormatUint(uint64(val), 10) case reflect.Uint32: - node.ID = strconv.FormatUint(uint64(v.Interface().(uint32)), 10) + val, ok := v.Interface().(uint32) + if !ok { + return nil, errors.New("could not assert uint32") + } + + node.ID = strconv.FormatUint(uint64(val), 10) case reflect.Uint64: - node.ID = strconv.FormatUint(v.Interface().(uint64), 10) + val, ok := v.Interface().(uint64) + if !ok { + return nil, errors.New("could not assert uint64") + } + + node.ID = strconv.FormatUint(val, 10) default: // We had a JSON float (numeric), but our field was not one of the // allowed numeric types @@ -304,14 +362,19 @@ func visitModelNode(model interface{}, included *map[string]*Node, if node.Attributes == nil { node.Attributes = make(map[string]json.RawMessage) } + var err error if fieldValue.Type() == reflect.TypeOf(decimal.Decimal{}) { - d := fieldValue.Interface().(decimal.Decimal) + d, ok := fieldValue.Interface().(decimal.Decimal) + if !ok { + return nil, fmt.Errorf("could not assert decimal.Decimal") + } if !decimal.MarshalJSONWithoutQuotes { return nil, fmt.Errorf("decimal.MarshalJSONWithoutQuotes needs to be turned on to export decimals as numbers") } + node.Attributes[args[1]] = json.RawMessage(d.String()) } else if fieldValue.Type() == reflect.TypeOf(new(decimal.Decimal)) { // A decimal pointer may be nil @@ -322,15 +385,22 @@ func visitModelNode(model interface{}, included *map[string]*Node, node.Attributes[args[1]] = []byte("null") } else { - d := fieldValue.Interface().(*decimal.Decimal) + d, ok := fieldValue.Interface().(*decimal.Decimal) + if !ok { + return nil, fmt.Errorf("could not assert decimal.Decimal") + } if !decimal.MarshalJSONWithoutQuotes { return nil, fmt.Errorf("decimal.MarshalJSONWithoutQuotes needs to be turned on to export decimals as numbers") } + node.Attributes[args[1]] = json.RawMessage(d.String()) } } else if fieldValue.Type() == reflect.TypeOf(time.Time{}) { - t := fieldValue.Interface().(time.Time) + t, ok := fieldValue.Interface().(time.Time) + if !ok { + return nil, fmt.Errorf("could not assert time.Time") + } if t.IsZero() { continue @@ -341,6 +411,7 @@ func visitModelNode(model interface{}, included *map[string]*Node, } else { node.Attributes[args[1]], err = json.Marshal(t.Unix()) } + if err != nil { return nil, err } @@ -353,7 +424,10 @@ func visitModelNode(model interface{}, included *map[string]*Node, node.Attributes[args[1]] = []byte("null") } else { - tm := fieldValue.Interface().(*time.Time) + tm, ok := fieldValue.Interface().(*time.Time) + if !ok { + return nil, fmt.Errorf("could not assert time.Time") + } if tm.IsZero() && omitEmpty { continue @@ -364,6 +438,7 @@ func visitModelNode(model interface{}, included *map[string]*Node, } else { node.Attributes[args[1]], err = json.Marshal(tm.Unix()) } + if err != nil { return nil, err } @@ -383,20 +458,24 @@ func visitModelNode(model interface{}, included *map[string]*Node, // We need to pass a pointer value ptr := reflect.New(fieldValue.Type()) ptr.Elem().Set(fieldValue) + n, err1 := visitModelNode(ptr.Interface(), nil, false) if err1 != nil { return nil, err1 } + node.Attributes[args[1]], err = json.Marshal(n.Attributes) } else if fieldValue.Type().Kind() == reflect.Ptr && fieldValue.Elem().Kind() == reflect.Struct { n, err1 := visitModelNode(fieldValue.Interface(), nil, false) if err1 != nil { return nil, err1 } + node.Attributes[args[1]], err = json.Marshal(n.Attributes) } else { node.Attributes[args[1]], err = json.Marshal(fieldValue.Interface()) } + if err != nil { return nil, err } @@ -441,11 +520,13 @@ func visitModelNode(model interface{}, included *map[string]*Node, er = err break } + relationship.Links = relLinks relationship.Meta = relMeta if sideload { shallowNodes := []*Node{} + for _, n := range relationship.Data { appendIncluded(included, n) shallowNodes = append(shallowNodes, toShallowNode(n)) @@ -461,7 +542,6 @@ func visitModelNode(model interface{}, included *map[string]*Node, } } else { // to-one relationships - // Handle null relationship case if fieldValue.IsNil() { node.Relationships[args[1]] = &RelationshipOneNode{Data: nil} @@ -480,6 +560,7 @@ func visitModelNode(model interface{}, included *map[string]*Node, if sideload { appendIncluded(included, relationship) + node.Relationships[args[1]] = &RelationshipOneNode{ Data: toShallowNode(relationship), Links: relLinks, @@ -493,7 +574,6 @@ func visitModelNode(model interface{}, included *map[string]*Node, } } } - } else { er = ErrBadJSONAPIStructTag break @@ -509,6 +589,7 @@ func visitModelNode(model interface{}, included *map[string]*Node, if er := jl.validate(); er != nil { return nil, er } + node.Links = linkableModel.JSONAPILinks() } @@ -564,6 +645,7 @@ func nodeMapValues(m *map[string]*Node) []*Node { nodes := make([]*Node, len(mp)) i := 0 + for _, n := range mp { nodes[i] = n i++ @@ -577,9 +659,11 @@ func convertToSliceInterface(i *interface{}) ([]interface{}, error) { if vals.Kind() != reflect.Slice { return nil, ErrExpectedSlice } + var response []interface{} for x := 0; x < vals.Len(); x++ { response = append(response, vals.Index(x).Interface()) } + return response, nil } diff --git a/http/jsonapi/response_test.go b/http/jsonapi/response_test.go index 405932f46..585739f61 100644 --- a/http/jsonapi/response_test.go +++ b/http/jsonapi/response_test.go @@ -13,11 +13,10 @@ import ( "testing" "time" + "github.com/shopspring/decimal" "github.com/stretchr/testify/require" "github.com/pace/bricks/pkg/isotime" - - "github.com/shopspring/decimal" ) func TestMarshalPayload(t *testing.T) { @@ -25,14 +24,16 @@ func TestMarshalPayload(t *testing.T) { if e != nil { panic(e) } + book := &Book{ID: 1, Decimal1: d} books := []*Book{book, {ID: 2}} + var jsonData map[string]interface{} // One out1 := bytes.NewBuffer(nil) - err := MarshalPayload(out1, book) - if err != nil { + + if err := MarshalPayload(out1, book); err != nil { t.Fatal(err) } @@ -43,21 +44,24 @@ func TestMarshalPayload(t *testing.T) { if err := json.Unmarshal(out1.Bytes(), &jsonData); err != nil { t.Fatal(err) } + if _, ok := jsonData["data"].(map[string]interface{}); !ok { t.Fatalf("data key did not contain an Hash/Dict/Map") } + fmt.Println(out1.String()) // Many out2 := bytes.NewBuffer(nil) - err = MarshalPayload(out2, books) - if err != nil { + + if err := MarshalPayload(out2, books); err != nil { t.Fatal(err) } if err := json.Unmarshal(out2.Bytes(), &jsonData); err != nil { t.Fatal(err) } + if _, ok := jsonData["data"].([]interface{}); !ok { t.Fatalf("data key did not contain an Array") } @@ -65,6 +69,7 @@ func TestMarshalPayload(t *testing.T) { func TestMarshalPayloadWithNulls(t *testing.T) { books := []*Book{nil, {ID: 101}, nil} + var jsonData map[string]interface{} out := bytes.NewBuffer(nil) @@ -75,14 +80,17 @@ func TestMarshalPayloadWithNulls(t *testing.T) { if err := json.Unmarshal(out.Bytes(), &jsonData); err != nil { t.Fatal(err) } + raw, ok := jsonData["data"] if !ok { t.Fatalf("data key does not exist") } + arr, ok := raw.([]interface{}) if !ok { t.Fatalf("data is not an Array") } + for i := 0; i < len(arr); i++ { if books[i] == nil && arr[i] != nil || books[i] != nil && arr[i] == nil { @@ -105,15 +113,35 @@ func TestMarshal_attrStringSlice(t *testing.T) { t.Fatal(err) } - jsonTags := jsonData["data"].(map[string]interface{})["attributes"].(map[string]interface{})["tags"].([]interface{}) + dataMap, ok := jsonData["data"].(map[string]interface{}) + if !ok { + t.Fatal("data was not a map") + } + + attributesMap, ok := dataMap["attributes"].(map[string]interface{}) + if !ok { + t.Fatal("data.attributes was not a map") + } + + jsonTags, ok := attributesMap["tags"].([]interface{}) + if !ok { + t.Fatal("data.attributes.tags was not a slice") + } + if e, a := len(tags), len(jsonTags); e != a { t.Fatalf("Was expecting tags of length %d got %d", e, a) } // Convert from []interface{} to []string jsonTagsStrings := []string{} + for _, tag := range jsonTags { - jsonTagsStrings = append(jsonTagsStrings, tag.(string)) + s, ok := tag.(string) + if !ok { + t.Fatalf("Was expecting tag to be a string, got %T", tag) + } + + jsonTagsStrings = append(jsonTagsStrings, s) } // Sort both @@ -139,25 +167,38 @@ func TestWithoutOmitsEmptyAnnotationOnRelation(t *testing.T) { if err := json.Unmarshal(out.Bytes(), &jsonData); err != nil { t.Fatal(err) } - relationships := jsonData["data"].(map[string]interface{})["relationships"].(map[string]interface{}) + + dataMap, ok := jsonData["data"].(map[string]interface{}) + if !ok { + t.Fatal("data was not a map") + } + + relationships, ok := dataMap["relationships"].(map[string]interface{}) + if !ok { + t.Fatal("data.relationships was not a map") + } // Verifiy the "posts" relation was an empty array posts, ok := relationships["posts"] if !ok { t.Fatal("Was expecting the data.relationships.posts key/value to have been present") } + postsMap, ok := posts.(map[string]interface{}) if !ok { t.Fatal("data.relationships.posts was not a map") } + postsData, ok := postsMap["data"] if !ok { t.Fatal("Was expecting the data.relationships.posts.data key/value to have been present") } + postsDataSlice, ok := postsData.([]interface{}) if !ok { t.Fatal("data.relationships.posts.data was not a slice []") } + if len(postsDataSlice) != 0 { t.Fatal("Was expecting the data.relationships.posts.data value to have been an empty array []") } @@ -167,14 +208,17 @@ func TestWithoutOmitsEmptyAnnotationOnRelation(t *testing.T) { if !postExists { t.Fatal("Was expecting the data.relationships.current_post key/value to have NOT been omitted") } + currentPostMap, ok := currentPost.(map[string]interface{}) if !ok { t.Fatal("data.relationships.current_post was not a map") } + currentPostData, ok := currentPostMap["data"] if !ok { t.Fatal("Was expecting the data.relationships.current_post.data key/value to have been present") } + if currentPostData != nil { t.Fatal("Was expecting the data.relationships.current_post.data value to have been nil/null") } @@ -199,7 +243,11 @@ func TestWithOmitsEmptyAnnotationOnRelation(t *testing.T) { if err := json.Unmarshal(out.Bytes(), &jsonData); err != nil { t.Fatal(err) } - payload := jsonData["data"].(map[string]interface{}) + + payload, ok := jsonData["data"].(map[string]interface{}) + if !ok { + t.Fatal("data was not a map") + } // Verify relationship was NOT set if val, exists := payload["relationships"]; exists { @@ -231,20 +279,34 @@ func TestWithOmitsEmptyAnnotationOnRelation_MixedData(t *testing.T) { if err := json.Unmarshal(out.Bytes(), &jsonData); err != nil { t.Fatal(err) } - payload := jsonData["data"].(map[string]interface{}) + + payload, ok := jsonData["data"].(map[string]interface{}) + if !ok { + t.Fatal("data was not a map") + } // Verify relationship was set if _, exists := payload["relationships"]; !exists { t.Fatal("Was expecting the data.relationships key/value to have NOT been empty") } - relationships := payload["relationships"].(map[string]interface{}) + relationships, ok := payload["relationships"].(map[string]interface{}) + if !ok { + t.Fatal("data.relationships was not a map") + } // Verify the relationship was not omitted, and is not null if val, exists := relationships["current_post"]; !exists { t.Fatal("Was expecting the data.relationships.current_post key/value to have NOT been omitted") - } else if val.(map[string]interface{})["data"] == nil { - t.Fatal("Was expecting the data.relationships.current_post value to have NOT been nil/null") + } else { + valMap, ok := val.(map[string]interface{}) + if !ok { + t.Fatal("Was expecting the data.relationships.current_post value to have been a map") + } + + if valMap["data"] == nil { + t.Fatal("Was expecting the data.relationships.current_post value to have NOT been nil/null") + } } } @@ -289,20 +351,32 @@ func TestWithOmitsEmptyAnnotationOnAttribute(t *testing.T) { } // Verify that there is no field "phones" in attributes - payload := jsonData["data"].(map[string]interface{}) - attributes := payload["attributes"].(map[string]interface{}) + payload, ok := jsonData["data"].(map[string]interface{}) + if !ok { + t.Fatal("data was not a map") + } + + attributes, ok := payload["attributes"].(map[string]interface{}) + if !ok { + t.Fatal("Was expecting the data.attributes key/value to have been a map") + } + if _, ok := attributes["title"]; !ok { t.Fatal("Was expecting the data.attributes.title to have NOT been omitted") } + if _, ok := attributes["phones"]; ok { t.Fatal("Was expecting the data.attributes.phones to have been omitted") } + if _, ok := attributes["address"]; ok { t.Fatal("Was expecting the data.attributes.phones to have been omitted") } + if _, ok := attributes["tags"]; !ok { t.Fatal("Was expecting the data.attributes.tags to have NOT been omitted") } + if _, ok := attributes["account"]; !ok { t.Fatal("Was expecting the data.attributes.account to have NOT been omitted") } @@ -325,14 +399,18 @@ func TestMarshalIDPtr(t *testing.T) { if err := json.Unmarshal(out.Bytes(), &jsonData); err != nil { t.Fatal(err) } - data := jsonData["data"].(map[string]interface{}) - // attributes := data["attributes"].(map[string]interface{}) + + data, ok := jsonData["data"].(map[string]interface{}) + if !ok { + t.Fatal("data was not a map") + } // Verify that the ID was sent val, exists := data["id"] if !exists { t.Fatal("Was expecting the data.id member to exist") } + if val != id { t.Fatalf("Was expecting the data.id member to be `%s`, got `%s`", id, val) } @@ -346,6 +424,7 @@ func TestMarshalOnePayload_omitIDString(t *testing.T) { foo := &Foo{Title: "Foo"} out := bytes.NewBuffer(nil) + if err := MarshalPayload(out, foo); err != nil { t.Fatal(err) } @@ -354,11 +433,15 @@ func TestMarshalOnePayload_omitIDString(t *testing.T) { if err := json.Unmarshal(out.Bytes(), &jsonData); err != nil { t.Fatal(err) } - payload := jsonData["data"].(map[string]interface{}) + + payload, ok := jsonData["data"].(map[string]interface{}) + if !ok { + t.Fatal("data was not a map") + } // Verify that empty ID of type string gets omitted. See: // https://github.com/google/jsonapi/issues/83#issuecomment-285611425 - _, ok := payload["id"] + _, ok = payload["id"] if ok { t.Fatal("Was expecting the data.id member to be omitted") } @@ -368,15 +451,14 @@ func TestMarshall_invalidIDType(t *testing.T) { type badIDStruct struct { ID *bool `jsonapi:"primary,cars"` } + id := true o := &badIDStruct{ID: &id} out := bytes.NewBuffer(nil) - if err := MarshalPayload(out, o); err != ErrBadJSONAPIID { - t.Fatalf( - "Was expecting a `%s` error, got `%s`", ErrBadJSONAPIID, err, - ) - } + + err := MarshalPayload(out, o) + require.ErrorIs(t, err, ErrBadJSONAPIID) } func TestOmitsEmptyAnnotation(t *testing.T) { @@ -394,12 +476,22 @@ func TestOmitsEmptyAnnotation(t *testing.T) { if err := json.Unmarshal(out.Bytes(), &jsonData); err != nil { t.Fatal(err) } - attributes := jsonData["data"].(map[string]interface{})["attributes"].(map[string]interface{}) + + data, ok := jsonData["data"].(map[string]interface{}) + if !ok { + t.Fatal("data was not a map") + } + + attributes, ok := data["attributes"].(map[string]interface{}) + if !ok { + t.Fatal("Was expecting the data.attributes key/value to have been a map") + } // Verify that the specifically omitted field were omitted if val, exists := attributes["title"]; exists { t.Fatalf("Was expecting the data.attributes.title key/value to have been omitted - it was not and had a value of %v", val) } + if val, exists := attributes["pages"]; exists { t.Fatalf("Was expecting the data.attributes.pages key/value to have been omitted - it was not and had a value of %v", val) } @@ -664,12 +756,14 @@ func TestSupportsLinkable(t *testing.T) { if data.Links == nil { t.Fatal("Expected data.links") } + links := *data.Links self, hasSelf := links["self"] if !hasSelf { t.Fatal("Expected 'self' link to be present") } + if _, isString := self.(string); !isString { t.Fatal("Expected 'self' to contain a string") } @@ -678,6 +772,7 @@ func TestSupportsLinkable(t *testing.T) { if !hasComments { t.Fatal("expect 'comments' to be present") } + commentsMap, isMap := comments.(map[string]interface{}) if !isMap { t.Fatal("Expected 'comments' to contain a map") @@ -687,6 +782,7 @@ func TestSupportsLinkable(t *testing.T) { if !hasHref { t.Fatal("Expect 'comments' to contain an 'href' key/value") } + if _, isString := commentsHref.(string); !isString { t.Fatal("Expected 'href' to contain a string") } @@ -695,16 +791,19 @@ func TestSupportsLinkable(t *testing.T) { if !hasMeta { t.Fatal("Expect 'comments' to contain a 'meta' key/value") } + commentsMetaMap, isMap := commentsMeta.(map[string]interface{}) if !isMap { t.Fatal("Expected 'comments' to contain a map") } commentsMetaObject := Meta(commentsMetaMap) + countsMap, isMap := commentsMetaObject["counts"].(map[string]interface{}) if !isMap { t.Fatal("Expected 'counts' to contain a map") } + for k, v := range countsMap { if _, isNum := v.(float64); !isNum { t.Fatalf("Exepected value at '%s' to be a numeric (float64)", k) @@ -746,7 +845,7 @@ func TestSupportsMetable(t *testing.T) { t.Fatalf("Expected data.meta") } - meta := Meta(*data.Meta) + meta := *data.Meta if e, a := "extra details regarding the blog", meta["detail"]; e != a { t.Fatalf("Was expecting meta.detail to be %q, got %q", e, a) } @@ -774,10 +873,11 @@ func TestRelations(t *testing.T) { if relations["posts"] == nil { t.Fatalf("Posts relationship was not materialized") } else { - if relations["posts"].(map[string]interface{})["links"] == nil { + if posts, ok := relations["posts"].(map[string]interface{}); !ok || posts["links"] == nil { t.Fatalf("Posts relationship links were not materialized") } - if relations["posts"].(map[string]interface{})["meta"] == nil { + + if posts, ok := relations["posts"].(map[string]interface{}); !ok || posts["meta"] == nil { t.Fatalf("Posts relationship meta were not materialized") } } @@ -785,15 +885,26 @@ func TestRelations(t *testing.T) { if relations["current_post"] == nil { t.Fatalf("Current post relationship was not materialized") } else { - if relations["current_post"].(map[string]interface{})["links"] == nil { + if currentPost, ok := relations["current_post"].(map[string]interface{}); !ok || currentPost["links"] == nil { t.Fatalf("Current post relationship links were not materialized") } - if relations["current_post"].(map[string]interface{})["meta"] == nil { + + if currentPost, ok := relations["current_post"].(map[string]interface{}); !ok || currentPost["meta"] == nil { t.Fatalf("Current post relationship meta were not materialized") } } - if len(relations["posts"].(map[string]interface{})["data"].([]interface{})) != 2 { + posts, ok := relations["posts"].(map[string]interface{}) + if !ok { + t.Fatalf("Expected posts to be a map") + } + + postsData, ok := posts["data"].([]interface{}) + if !ok { + t.Fatalf("Expected posts.data to be a slice") + } + + if len(postsData) != 2 { t.Fatalf("Did not materialize two posts") } } @@ -975,6 +1086,7 @@ func TestMarshalMany_SliceOfInterfaceAndSliceOfStructsSameJSON(t *testing.T) { {ID: 2, Author: "shwoodard", ISBN: "xyz"}, } interfaces := []interface{}{} + for _, s := range structs { interfaces = append(interfaces, s) } @@ -984,6 +1096,7 @@ func TestMarshalMany_SliceOfInterfaceAndSliceOfStructsSameJSON(t *testing.T) { if err := MarshalPayload(structsOut, structs); err != nil { t.Fatal(err) } + interfacesOut := new(bytes.Buffer) if err := MarshalPayload(interfacesOut, interfaces); err != nil { t.Fatal(err) @@ -994,6 +1107,7 @@ func TestMarshalMany_SliceOfInterfaceAndSliceOfStructsSameJSON(t *testing.T) { if err := json.Unmarshal(structsOut.Bytes(), &structsData); err != nil { t.Fatal(err) } + if err := json.Unmarshal(interfacesOut.Bytes(), &interfacesData); err != nil { t.Fatal(err) } @@ -1006,15 +1120,15 @@ func TestMarshalMany_SliceOfInterfaceAndSliceOfStructsSameJSON(t *testing.T) { func TestMarshal_InvalidIntefaceArgument(t *testing.T) { out := new(bytes.Buffer) - if err := MarshalPayload(out, true); err != ErrUnexpectedType { - t.Fatal("Was expecting an error") - } - if err := MarshalPayload(out, 25); err != ErrUnexpectedType { - t.Fatal("Was expecting an error") - } - if err := MarshalPayload(out, Book{}); err != ErrUnexpectedType { - t.Fatal("Was expecting an error") - } + + err := MarshalPayload(out, true) + require.ErrorIs(t, err, ErrUnexpectedType) + + err = MarshalPayload(out, 25) + require.ErrorIs(t, err, ErrUnexpectedType) + + err = MarshalPayload(out, Book{}) + require.ErrorIs(t, err, ErrUnexpectedType) } func testBlog() *Blog { diff --git a/http/jsonapi/runtime.go b/http/jsonapi/runtime.go index 7fd67db16..d7f382f33 100644 --- a/http/jsonapi/runtime.go +++ b/http/jsonapi/runtime.go @@ -106,6 +106,7 @@ func (r *Runtime) instrumentCall(start Event, stop Event, c func() error) error } begin := time.Now() + Instrumentation(r, start, instrumentationGUID, time.Duration(0)) if err := c(); err != nil { @@ -128,5 +129,6 @@ func newUUID() (string, error) { uuid[8] = uuid[8]&^0xc0 | 0x80 // version 4 (pseudo-random); see section 4.1.3 uuid[6] = uuid[6]&^0xf0 | 0x40 + return fmt.Sprintf("%x-%x-%x-%x-%x", uuid[0:4], uuid[4:6], uuid[6:8], uuid[8:10], uuid[10:]), nil } diff --git a/http/jsonapi/runtime/consts.go b/http/jsonapi/runtime/consts.go index 56934f9ba..9c6b11888 100644 --- a/http/jsonapi/runtime/consts.go +++ b/http/jsonapi/runtime/consts.go @@ -3,5 +3,5 @@ package runtime // JSONAPIContentType is the content type required for -// jsonapi based requests and responses +// jsonapi based requests and responses. const JSONAPIContentType = "application/vnd.api+json" diff --git a/http/jsonapi/runtime/error.go b/http/jsonapi/runtime/error.go index ecacebf7c..83da29a10 100644 --- a/http/jsonapi/runtime/error.go +++ b/http/jsonapi/runtime/error.go @@ -4,6 +4,7 @@ package runtime import ( "encoding/json" + "errors" "net/http" "strconv" "strings" @@ -39,29 +40,30 @@ type Error struct { Meta *map[string]interface{} `json:"meta,omitempty"` } -// setHttpStatus sets the http status for the error object +// setHttpStatus sets the http status for the error object. func (e *Error) setHTTPStatus(code int) { e.Status = strconv.Itoa(code) } -// Error implements the error interface +// Error implements the error interface. func (e Error) Error() string { return e.Title } -// Errors is a list of errors +// Errors is a list of errors. type Errors []*Error -// Error implements the error interface +// Error implements the error interface. func (e Errors) Error() string { messages := make([]string, len(e)) for i, err := range e { messages[i] = err.Error() } + return strings.Join(messages, "\n") } -// setHttpStatus sets the http status for the error object +// setHttpStatus sets the http status for the error object. func (e Errors) setHTTPStatus(code int) { status := strconv.Itoa(code) for _, err := range e { @@ -69,29 +71,34 @@ func (e Errors) setHTTPStatus(code int) { } } -// setID sets the error id on the request +// setID sets the error id on the request. func (e Errors) setID(errorID string) { for _, err := range e { err.ID = errorID } } -// WriteError writes a jsonapi error message to the client +// WriteError writes a jsonapi error message to the client. func WriteError(w http.ResponseWriter, code int, err error) { w.Header().Set("Content-Type", JSONAPIContentType) w.WriteHeader(code) - // convert error type for marshaling - var errList errorObjects - - switch v := err.(type) { - case Error: - errList.List = append(errList.List, &v) - case *Error: - errList.List = append(errList.List, v) - case Errors: - errList.List = v - default: + var ( + // convert error type for marshaling + errList errorObjects + + errError Error + errErrorPtr *Error + errErrors Errors + ) + + if errors.As(err, &errError) { + errList.List = append(errList.List, &errError) + } else if errors.As(err, &errErrorPtr) { + errList.List = append(errList.List, errErrorPtr) + } else if errors.As(err, &errErrors) { + errList.List = errErrors + } else { errList.List = []*Error{ {Title: err.Error()}, } @@ -110,8 +117,8 @@ func WriteError(w http.ResponseWriter, code int, err error) { // render the error to the client enc := json.NewEncoder(w) enc.SetIndent("", " ") - err = enc.Encode(errList) - if err != nil { + + if err := enc.Encode(errList); err != nil { log.Logger().Info().Str("req_id", reqID). Err(err).Msg("Unable to send error response to the client") } diff --git a/http/jsonapi/runtime/error_test.go b/http/jsonapi/runtime/error_test.go index 77c124067..1a410ee27 100644 --- a/http/jsonapi/runtime/error_test.go +++ b/http/jsonapi/runtime/error_test.go @@ -9,6 +9,8 @@ import ( "net/http/httptest" "reflect" "testing" + + "github.com/stretchr/testify/assert" ) func TestErrorMarshaling(t *testing.T) { @@ -46,19 +48,25 @@ func TestErrorMarshaling(t *testing.T) { WriteError(rec, testCase.httpStatus, testCase.err) resp := rec.Result() - defer resp.Body.Close() + + defer func() { + err := resp.Body.Close() + assert.NoError(t, err) + }() if resp.StatusCode != testCase.httpStatus { t.Errorf("expected the response code %d got: %d", testCase.httpStatus, resp.StatusCode) } + if ct := resp.Header.Get("Content-Type"); ct != JSONAPIContentType { t.Errorf("expected the response code %q got: %q", JSONAPIContentType, ct) } var errList errorObjects + dec := json.NewDecoder(resp.Body) - err := dec.Decode(&errList) - if err != nil { + + if err := dec.Decode(&errList); err != nil { t.Fatal(err) } @@ -82,6 +90,7 @@ func TestErrors(t *testing.T) { &Error{Title: "foo2", Detail: "bar2"}, } result := "foo\nfoo2" + if errs.Error() != result { t.Errorf("expected %q got: %q", result, errs.Error()) } @@ -89,7 +98,7 @@ func TestErrors(t *testing.T) { func TestError(t *testing.T) { err := Error{} - err.setHTTPStatus(200) + err.setHTTPStatus(http.StatusOK) result := "200" if err.Status != result { diff --git a/http/jsonapi/runtime/marshalling.go b/http/jsonapi/runtime/marshalling.go index c28a2ed80..2eced9519 100644 --- a/http/jsonapi/runtime/marshalling.go +++ b/http/jsonapi/runtime/marshalling.go @@ -3,6 +3,7 @@ package runtime import ( + "errors" "fmt" "net" "net/http" @@ -15,10 +16,12 @@ import ( // Unmarshal processes the request content and fills passed data struct with the // correct jsonapi content. After un-marshaling the struct will be validated with // specified go-validator struct tags. -// In case of an error, an jsonapi error message will be directly send to the client +// In case of an error, an jsonapi error message will be directly send to the client. func Unmarshal(w http.ResponseWriter, r *http.Request, data interface{}) bool { // don't leak , but error can't be handled - defer r.Body.Close() // nolint: errcheck + defer func() { + _ = r.Body.Close() + }() // verify that the client accepts our response // Note: logically this would be done before marshalling, @@ -40,10 +43,9 @@ func Unmarshal(w http.ResponseWriter, r *http.Request, data interface{}) bool { } // parse request - err := jsonapi.UnmarshalPayload(r.Body, data) - if err != nil { + if err := jsonapi.UnmarshalPayload(r.Body, data); err != nil { WriteError(w, http.StatusUnprocessableEntity, - fmt.Errorf("can't parse content: %v", err)) + fmt.Errorf("can't parse content: %w", err)) return false } @@ -54,10 +56,12 @@ func Unmarshal(w http.ResponseWriter, r *http.Request, data interface{}) bool { // UnmarshalMany processes the request content that has an array of objects and fills passed data struct with the // correct jsonapi content. After un-marshaling the struct will be validated with // specified go-validator struct tags. -// In case of an error, an jsonapi error message will be directly send to the client +// In case of an error, an jsonapi error message will be directly send to the client. func UnmarshalMany(w http.ResponseWriter, r *http.Request, t reflect.Type) (bool, []interface{}) { // don't leak , but error can't be handled - defer r.Body.Close() // nolint: errcheck + defer func() { + _ = r.Body.Close() + }() // verify that the client accepts our response // Note: logically this would be done before marshalling, @@ -82,7 +86,7 @@ func UnmarshalMany(w http.ResponseWriter, r *http.Request, t reflect.Type) (bool data, err := jsonapi.UnmarshalManyPayload(r.Body, t) if err != nil { WriteError(w, http.StatusUnprocessableEntity, - fmt.Errorf("can't parse content: %v", err)) + fmt.Errorf("can't parse content: %w", err)) return false, nil } // validate request @@ -91,24 +95,24 @@ func UnmarshalMany(w http.ResponseWriter, r *http.Request, t reflect.Type) (bool return false, nil } } + return true, data } // Marshal the given data and writes them into the response writer, sets -// the content-type and code as well +// the content-type and code as well. func Marshal(w http.ResponseWriter, data interface{}, code int) { // write response header w.Header().Set("Content-Type", JSONAPIContentType) w.WriteHeader(code) // write marshaled response body - err := jsonapi.MarshalPayload(w, data) - if err != nil { - switch err.(type) { - case *net.OpError: - log.Errorf("Connection error: %s", err) - default: - panic(fmt.Errorf("failed to marshal jsonapi response for %#v: %s", data, err)) + if err := jsonapi.MarshalPayload(w, data); err != nil { + var opErr *net.OpError + if errors.As(err, &opErr) { + log.Errorf("Connection error: %v", err) + } else { + panic(fmt.Errorf("failed to marshal jsonapi response for %#v: %w", data, err)) } } } diff --git a/http/jsonapi/runtime/marshalling_test.go b/http/jsonapi/runtime/marshalling_test.go index 36ae03aaf..de64a9b5e 100644 --- a/http/jsonapi/runtime/marshalling_test.go +++ b/http/jsonapi/runtime/marshalling_test.go @@ -11,11 +11,13 @@ import ( "reflect" "strings" "testing" + + "github.com/stretchr/testify/assert" ) func TestUnmarshalAccept(t *testing.T) { rec := httptest.NewRecorder() - req := httptest.NewRequest("POST", "/", nil) + req := httptest.NewRequest(http.MethodPost, "/", nil) ok := Unmarshal(rec, req, nil) if ok { @@ -23,7 +25,12 @@ func TestUnmarshalAccept(t *testing.T) { } resp := rec.Result() - defer resp.Body.Close() + + defer func() { + err := resp.Body.Close() + assert.NoError(t, err) + }() + if resp.StatusCode != http.StatusNotAcceptable { t.Errorf("Expected status code %d got: %d", http.StatusNotAcceptable, resp.StatusCode) } @@ -31,7 +38,7 @@ func TestUnmarshalAccept(t *testing.T) { func TestUnmarshalContentType(t *testing.T) { rec := httptest.NewRecorder() - req := httptest.NewRequest("POST", "/", nil) + req := httptest.NewRequest(http.MethodPost, "/", nil) req.Header.Set("Accept", JSONAPIContentType) ok := Unmarshal(rec, req, nil) @@ -40,7 +47,12 @@ func TestUnmarshalContentType(t *testing.T) { } resp := rec.Result() - defer resp.Body.Close() + + defer func() { + err := resp.Body.Close() + assert.NoError(t, err) + }() + if resp.StatusCode != http.StatusUnsupportedMediaType { t.Errorf("Expected status code %d got: %d", http.StatusUnsupportedMediaType, resp.StatusCode) } @@ -48,7 +60,7 @@ func TestUnmarshalContentType(t *testing.T) { func TestUnmarshalContent(t *testing.T) { rec := httptest.NewRecorder() - req := httptest.NewRequest("POST", "/", strings.NewReader(`{"data": 1}`)) + req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(`{"data": 1}`)) req.Header.Set("Accept", JSONAPIContentType) req.Header.Set("Content-Type", JSONAPIContentType) @@ -58,13 +70,19 @@ func TestUnmarshalContent(t *testing.T) { } var article Article + ok := Unmarshal(rec, req, &article) if ok { t.Error("Un-marshalling should fail") } resp := rec.Result() - defer resp.Body.Close() + + defer func() { + err := resp.Body.Close() + assert.NoError(t, err) + }() + if resp.StatusCode != http.StatusUnprocessableEntity { t.Errorf("Expected status code %d got: %d", http.StatusUnprocessableEntity, resp.StatusCode) } @@ -72,7 +90,7 @@ func TestUnmarshalContent(t *testing.T) { func TestUnmarshalArticle(t *testing.T) { rec := httptest.NewRecorder() - req := httptest.NewRequest("POST", "/", strings.NewReader(`{"data":{ + req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(`{"data":{ "type": "articles", "id": "cb855aff-f03c-4307-9a22-ab5fcc6b6d7c", "attributes": { @@ -88,20 +106,27 @@ func TestUnmarshalArticle(t *testing.T) { } var article Article - ok := Unmarshal(rec, req, &article) + ok := Unmarshal(rec, req, &article) if !ok { t.Error("Un-marshalling should have been ok") } resp := rec.Result() - defer resp.Body.Close() + + defer func() { + err := resp.Body.Close() + assert.NoError(t, err) + }() + if resp.StatusCode != http.StatusOK { t.Errorf("Expected status code %d got: %d", http.StatusOK, resp.StatusCode) + b, err := io.ReadAll(resp.Body) if err != nil { t.Fatal(err) } + t.Error(string(b[:])) } @@ -109,6 +134,7 @@ func TestUnmarshalArticle(t *testing.T) { if article.ID != uuid { t.Errorf("article.ID expected %q got: %q", uuid, article.ID) } + if article.Title != "This is my first blog" { t.Errorf("article.ID expected \"This is my first blog\" got: %q", article.Title) } @@ -116,7 +142,7 @@ func TestUnmarshalArticle(t *testing.T) { func TestUnmarshalArticles(t *testing.T) { rec := httptest.NewRecorder() - req := httptest.NewRequest("POST", "/", strings.NewReader(`{"data":[ + req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(`{"data":[ { "type":"article", "id": "82180c8d-0ab6-4946-9298-61d3c8d13da4", @@ -134,31 +160,39 @@ func TestUnmarshalArticles(t *testing.T) { ]}`)) req.Header.Set("Accept", JSONAPIContentType) req.Header.Set("Content-Type", JSONAPIContentType) + type Article struct { ID string `jsonapi:"primary,article" valid:"optional,uuid"` Title string `jsonapi:"attr,title" valid:"required"` } ok, articles := UnmarshalMany(rec, req, reflect.TypeOf(new(Article))) - if !ok { t.Error("Un-marshalling many should have been ok") } resp := rec.Result() - defer resp.Body.Close() + + defer func() { + err := resp.Body.Close() + assert.NoError(t, err) + }() + if resp.StatusCode != http.StatusOK { t.Errorf("Expected status code %d got: %d", http.StatusOK, resp.StatusCode) + b, err := io.ReadAll(resp.Body) if err != nil { t.Fatal(err) } + t.Error(string(b[:])) } if len(articles) != 2 { t.Errorf("Expected 2 articles, got %d", len(articles)) } + expected := []*Article{ { ID: "82180c8d-0ab6-4946-9298-61d3c8d13da4", @@ -169,11 +203,17 @@ func TestUnmarshalArticles(t *testing.T) { Title: "This is the second article", }, } + for i := range articles { - got := articles[i].(*Article) + got, ok := articles[i].(*Article) + if !ok { + t.Errorf("Expected type *Article got: %T", articles[i]) + } + if expected[i].ID != got.ID { t.Errorf("article.ID expected %q got: %q", expected[i].ID, got.ID) } + if expected[i].Title != got.Title { t.Errorf("article.ID expected \"%s\" got: %q", expected[i].ID, got.Title) } @@ -195,7 +235,12 @@ func TestMarshalArticle(t *testing.T) { Marshal(rec, &article, http.StatusOK) resp := rec.Result() - defer resp.Body.Close() + + defer func() { + err := resp.Body.Close() + assert.NoError(t, err) + }() + if resp.StatusCode != http.StatusOK { t.Errorf("Expected status code %d got: %d", http.StatusOK, resp.StatusCode) } @@ -245,6 +290,7 @@ func TestMarshalConnectionError(t *testing.T) { t.Fatal("Was not expecting a panic") } }() + rec := writer{} Marshal(rec, &struct{}{}, http.StatusOK) } diff --git a/http/jsonapi/runtime/parameters.go b/http/jsonapi/runtime/parameters.go index 54f661b89..c4f18cdfa 100644 --- a/http/jsonapi/runtime/parameters.go +++ b/http/jsonapi/runtime/parameters.go @@ -14,19 +14,19 @@ import ( "github.com/pace/bricks/pkg/isotime" ) -// ScanIn help to avoid missuse using iota for the possible values +// ScanIn help to avoid missuse using iota for the possible values. type ScanIn int const ( - // ScanInPath hints the scanner to scan the input + // ScanInPath hints the scanner to scan the input. ScanInPath ScanIn = iota - // ScanInQuery hints the scanner to scan the request url query + // ScanInQuery hints the scanner to scan the request url query. ScanInQuery - // ScanInHeader ints the scanner to scan the request header + // ScanInHeader ints the scanner to scan the request header. ScanInHeader ) -// ScanParameter configured the ScanParameters function +// ScanParameter configured the ScanParameters function. type ScanParameter struct { // Data contains the reference to the parameter, that should // be scanned to @@ -39,7 +39,7 @@ type ScanParameter struct { Name string } -// BuildInvalidValueError build a new error, using the passed type and data +// BuildInvalidValueError build a new error, using the passed type and data. func (p *ScanParameter) BuildInvalidValueError(typ reflect.Type, data string) error { return &Error{ Title: fmt.Sprintf("invalid value for %s", p.Name), @@ -53,7 +53,7 @@ func (p *ScanParameter) BuildInvalidValueError(typ reflect.Type, data string) er // ScanParameters scans the request using the given path parameter objects // in case an error is encountered a 400 along with a jsonapi errors object // is sent to the ResponseWriter and false is returned. Returns true if all -// values were scanned successfully. The used scanning function is fmt.Sscan +// values were scanned successfully. The used scanning function is fmt.Sscan. func ScanParameters(w http.ResponseWriter, r *http.Request, parameters ...*ScanParameter) bool { for _, param := range parameters { var scanData string @@ -72,6 +72,7 @@ func ScanParameters(w http.ResponseWriter, r *http.Request, parameters ...*ScanP size := len(input) array := reflect.MakeSlice(reValue.Type(), size, size) invalid := 0 + for i := 0; i < size; i++ { if input[i] == "" { invalid++ @@ -79,7 +80,8 @@ func ScanParameters(w http.ResponseWriter, r *http.Request, parameters ...*ScanP } arrElem := array.Index(i - invalid) - n, _ := Scan(input[i], arrElem.Addr().Interface()) // nolint: gosec + n, _ := Scan(input[i], arrElem.Addr().Interface()) + if n != 1 { WriteError(w, http.StatusBadRequest, param.BuildInvalidValueError(arrElem.Type(), input[i])) return false @@ -89,6 +91,7 @@ func ScanParameters(w http.ResponseWriter, r *http.Request, parameters ...*ScanP if invalid > 0 { array = array.Slice(0, size-invalid) } + reValue.Set(array) // skip parsing at the bottom of the loop @@ -110,10 +113,11 @@ func ScanParameters(w http.ResponseWriter, r *http.Request, parameters ...*ScanP return false } } + return true } -// Scan works like fmt.Sscan except for strings and decimals, they are directly assigned +// Scan works like fmt.Sscan except for strings and decimals, they are directly assigned. func Scan(str string, data interface{}) (int, error) { // handle decimal if d, ok := data.(*decimal.Decimal); ok { @@ -121,7 +125,9 @@ func Scan(str string, data interface{}) (int, error) { if err != nil { return 0, err } + *d = nd + return 1, nil } @@ -133,6 +139,7 @@ func Scan(str string, data interface{}) (int, error) { } *t = nt + return 1, nil } @@ -143,5 +150,5 @@ func Scan(str string, data interface{}) (int, error) { return 1, nil } - return fmt.Sscan(str, data) // nolint: gosec + return fmt.Sscan(str, data) } diff --git a/http/jsonapi/runtime/parameters_test.go b/http/jsonapi/runtime/parameters_test.go index d5dcfe994..3d8912645 100644 --- a/http/jsonapi/runtime/parameters_test.go +++ b/http/jsonapi/runtime/parameters_test.go @@ -4,9 +4,12 @@ package runtime import ( "encoding/json" + "net/http" "net/http/httptest" "testing" "time" + + "github.com/stretchr/testify/assert" ) func TestScanStringParametersInQuery(t *testing.T) { @@ -21,11 +24,12 @@ func TestScanStringParametersInQuery(t *testing.T) { } for _, tc := range tests { - req := httptest.NewRequest("GET", tc.path, nil) + req := httptest.NewRequest(http.MethodGet, tc.path, nil) rec := httptest.NewRecorder() + var param0 string + ok := ScanParameters(rec, req, &ScanParameter{¶m0, ScanInQuery, "", "q"}) - // Parsing if !ok { t.Errorf("expected the scanning of %q to be successful", tc.path) } @@ -50,11 +54,12 @@ func TestScanTimeParametersInQuery(t *testing.T) { } for _, tc := range tests { - req := httptest.NewRequest("GET", tc.path, nil) + req := httptest.NewRequest(http.MethodGet, tc.path, nil) rec := httptest.NewRecorder() + var param0 time.Time + ok := ScanParameters(rec, req, &ScanParameter{¶m0, ScanInQuery, "", "q"}) - // Parsing if !ok { t.Errorf("expected the scanning of %q to be successful", tc.path) } @@ -80,11 +85,12 @@ func TestScanBoolParametersInQuery(t *testing.T) { } for _, tc := range tests { - req := httptest.NewRequest("GET", tc.path, nil) + req := httptest.NewRequest(http.MethodGet, tc.path, nil) rec := httptest.NewRecorder() + var param0 bool + ok := ScanParameters(rec, req, &ScanParameter{¶m0, ScanInQuery, "", "b"}) - // Parsing if !ok { t.Errorf("expected the scanning of %q to be successful", tc.path) } @@ -96,20 +102,33 @@ func TestScanBoolParametersInQuery(t *testing.T) { } func TestScanNumericParametersInPath(t *testing.T) { - req := httptest.NewRequest("GET", "/foo/", nil) + req := httptest.NewRequest(http.MethodGet, "/foo/", nil) rec := httptest.NewRecorder() + var param0 uint + var param1 uint8 + var param2 uint16 + var param3 uint32 + var param4 uint64 + var param10 int + var param11 int8 + var param12 int16 + var param13 int32 + var param14 int64 + var param20 float32 + var param21 float64 + ok := ScanParameters(rec, req, &ScanParameter{¶m0, ScanInPath, "12", "num"}, &ScanParameter{¶m1, ScanInPath, "12", "num"}, @@ -134,15 +153,19 @@ func TestScanNumericParametersInPath(t *testing.T) { if param0 != uint(12) { t.Errorf("expected parsing result %#v got: %#v", uint(12), param0) } + if param1 != uint8(12) { t.Errorf("expected parsing result %#v got: %#v", uint8(12), param1) } + if param2 != uint16(12) { t.Errorf("expected parsing result %#v got: %#v", uint16(12), param2) } + if param3 != uint32(12) { t.Errorf("expected parsing result %#v got: %#v", uint32(12), param3) } + if param4 != uint64(12) { t.Errorf("expected parsing result %#v got: %#v", uint64(12), param4) } @@ -151,15 +174,19 @@ func TestScanNumericParametersInPath(t *testing.T) { if param10 != int(-12) { t.Errorf("expected parsing result %#v got: %#v", int(-12), param10) } + if param11 != int8(-12) { t.Errorf("expected parsing result %#v got: %#v", int8(-12), param11) } + if param12 != int16(-12) { t.Errorf("expected parsing result %#v got: %#v", int16(-12), param12) } + if param13 != int32(-12) { t.Errorf("expected parsing result %#v got: %#v", int32(-12), param13) } + if param14 != int64(-12) { t.Errorf("expected parsing result %#v got: %#v", int64(-12), param14) } @@ -168,19 +195,26 @@ func TestScanNumericParametersInPath(t *testing.T) { if param20 != float32(-12.123123123123123123123123) { t.Errorf("expected parsing result %#v got: %#v", float32(-12.123123123123123123123123), param20) } + if param21 != float64(-12.123123123123123123123123) { t.Errorf("expected parsing result %#v got: %#v", float64(-12.123123123123123123123123), param21) } } func TestScanNumericParametersInQueryUint(t *testing.T) { - req := httptest.NewRequest("GET", "/foo?num=12", nil) + req := httptest.NewRequest(http.MethodGet, "/foo?num=12", nil) rec := httptest.NewRecorder() + var param0 uint + var param1 uint8 + var param2 uint16 + var param3 uint32 + var param4 uint64 + ok := ScanParameters(rec, req, &ScanParameter{¶m0, ScanInQuery, "", "num"}, &ScanParameter{¶m1, ScanInQuery, "", "num"}, @@ -198,28 +232,38 @@ func TestScanNumericParametersInQueryUint(t *testing.T) { if param0 != uint(12) { t.Errorf("expected parsing result %#v got: %#v", uint(12), param0) } + if param1 != uint8(12) { t.Errorf("expected parsing result %#v got: %#v", uint8(12), param1) } + if param2 != uint16(12) { t.Errorf("expected parsing result %#v got: %#v", uint16(12), param2) } + if param3 != uint32(12) { t.Errorf("expected parsing result %#v got: %#v", uint32(12), param3) } + if param4 != uint64(12) { t.Errorf("expected parsing result %#v got: %#v", uint64(12), param4) } } func TestScanNumericParametersInQueryInt(t *testing.T) { - req := httptest.NewRequest("GET", "/foo?num=-12", nil) + req := httptest.NewRequest(http.MethodGet, "/foo?num=-12", nil) rec := httptest.NewRecorder() + var param10 int + var param11 int8 + var param12 int16 + var param13 int32 + var param14 int64 + ok := ScanParameters(rec, req, &ScanParameter{¶m10, ScanInQuery, "", "num"}, &ScanParameter{¶m11, ScanInQuery, "", "num"}, @@ -237,25 +281,32 @@ func TestScanNumericParametersInQueryInt(t *testing.T) { if param10 != int(-12) { t.Errorf("expected parsing result %#v got: %#v", int(-12), param10) } + if param11 != int8(-12) { t.Errorf("expected parsing result %#v got: %#v", int8(-12), param11) } + if param12 != int16(-12) { t.Errorf("expected parsing result %#v got: %#v", int16(-12), param12) } + if param13 != int32(-12) { t.Errorf("expected parsing result %#v got: %#v", int32(-12), param13) } + if param14 != int64(-12) { t.Errorf("expected parsing result %#v got: %#v", int64(-12), param14) } } func TestScanNumericParametersInQueryFloat(t *testing.T) { - req := httptest.NewRequest("GET", "/foo?num=-12.123123123123123123123123", nil) + req := httptest.NewRequest(http.MethodGet, "/foo?num=-12.123123123123123123123123", nil) rec := httptest.NewRecorder() + var param20 float32 + var param21 float64 + ok := ScanParameters(rec, req, &ScanParameter{¶m20, ScanInQuery, "", "num"}, &ScanParameter{¶m21, ScanInQuery, "", "num"}, @@ -270,15 +321,18 @@ func TestScanNumericParametersInQueryFloat(t *testing.T) { if param20 != float32(-12.123123123123123123123123) { t.Errorf("expected parsing result %#v got: %#v", float32(-12.123123123123123123123123), param20) } + if param21 != float64(-12.123123123123123123123123) { t.Errorf("expected parsing result %#v got: %#v", float64(-12.123123123123123123123123), param21) } } func TestScanNumericParametersInQueryFloatArray(t *testing.T) { - req := httptest.NewRequest("GET", "/foo?num=-12.123123123123123123123123&num=-987.123123123123123123123123&num=", nil) + req := httptest.NewRequest(http.MethodGet, "/foo?num=-12.123123123123123123123123&num=-987.123123123123123123123123&num=", nil) rec := httptest.NewRecorder() + var param []float32 + ok := ScanParameters(rec, req, &ScanParameter{¶m, ScanInQuery, "", "num"}, ) @@ -296,15 +350,18 @@ func TestScanNumericParametersInQueryFloatArray(t *testing.T) { if param[0] != float32(-12.123123123123123123123123) { t.Errorf("expected parsing result %#v got: %#v", float32(-12.123123123123123123123123), param[0]) } + if param[1] != float32(-987.123123123123123123123123) { t.Errorf("expected parsing result %#v got: %#v", float32(-987.123123123123123123123123), param[1]) } } func TestScanNumericParametersInQueryFloatArrayFail(t *testing.T) { - req := httptest.NewRequest("GET", "/foo?num=-12.123123123123123123123123&num=stuff", nil) + req := httptest.NewRequest(http.MethodGet, "/foo?num=-12.123123123123123123123123&num=stuff", nil) rec := httptest.NewRecorder() + var param []float32 + ok := ScanParameters(rec, req, &ScanParameter{¶m, ScanInQuery, "", "num"}, ) @@ -315,12 +372,17 @@ func TestScanNumericParametersInQueryFloatArrayFail(t *testing.T) { } resp := rec.Result() - defer resp.Body.Close() + + defer func() { + err := resp.Body.Close() + assert.NoError(t, err) + }() var errList errorObjects + dec := json.NewDecoder(resp.Body) - err := dec.Decode(&errList) - if err != nil { + + if err := dec.Decode(&errList); err != nil { t.Fatal(err) } @@ -332,19 +394,24 @@ func TestScanNumericParametersInQueryFloatArrayFail(t *testing.T) { if r := "invalid value for num"; errObj.Title != r { t.Errorf("expected title %q got: %q", r, errObj.Title) } + if r := "400"; errObj.Status != r { t.Errorf("expected status %q got: %q", r, errObj.Status) } + if r := "num"; (*errObj.Source)["parameter"] != r { t.Errorf("expected source parameter %q got: %q", r, (*errObj.Source)["parameter"]) } } func TestScanParametersHeader(t *testing.T) { - req := httptest.NewRequest("GET", "/foo", nil) + req := httptest.NewRequest(http.MethodGet, "/foo", nil) req.Header.Set("num", "123") + rec := httptest.NewRecorder() + var param int + ok := ScanParameters(rec, req, &ScanParameter{¶m, ScanInHeader, "", "num"}, ) @@ -366,9 +433,11 @@ func TestScanParametersHeader(t *testing.T) { } func TestScanParametersError(t *testing.T) { - req := httptest.NewRequest("GET", "/foo?num=-12", nil) + req := httptest.NewRequest(http.MethodGet, "/foo?num=-12", nil) rec := httptest.NewRecorder() + var param uint + ok := ScanParameters(rec, req, &ScanParameter{¶m, ScanInQuery, "", "num"}, ) @@ -379,12 +448,17 @@ func TestScanParametersError(t *testing.T) { } resp := rec.Result() - defer resp.Body.Close() + + defer func() { + err := resp.Body.Close() + assert.NoError(t, err) + }() var errList errorObjects + dec := json.NewDecoder(resp.Body) - err := dec.Decode(&errList) - if err != nil { + + if err := dec.Decode(&errList); err != nil { t.Fatal(err) } @@ -396,9 +470,11 @@ func TestScanParametersError(t *testing.T) { if r := "invalid value for num"; errObj.Title != r { t.Errorf("expected title %q got: %q", r, errObj.Title) } + if r := "400"; errObj.Status != r { t.Errorf("expected status %q got: %q", r, errObj.Status) } + if r := "num"; (*errObj.Source)["parameter"] != r { t.Errorf("expected source parameter %q got: %q", r, (*errObj.Source)["parameter"]) } diff --git a/http/jsonapi/runtime/standard_params.go b/http/jsonapi/runtime/standard_params.go index c45859d55..6029ed525 100644 --- a/http/jsonapi/runtime/standard_params.go +++ b/http/jsonapi/runtime/standard_params.go @@ -24,8 +24,7 @@ type config struct { var cfg config func init() { - err := env.Parse(&cfg) - if err != nil { + if err := env.Parse(&cfg); err != nil { log.Fatalf("Failed to parse jsonapi params from environment: %v", err) } } @@ -37,7 +36,7 @@ type ValueSanitizer interface { SanitizeValue(fieldName string, value string) (interface{}, error) } -// ColumnMapper maps the name of a filter or sorting parameter to a database column name +// ColumnMapper maps the name of a filter or sorting parameter to a database column name. type ColumnMapper interface { // Map maps the value, this function decides if the value is allowed and translates it to a database column name, // the function returns the database column name and a bool that indicates that the value is allowed and mapped @@ -45,25 +44,25 @@ type ColumnMapper interface { } // MapMapper is a very easy ColumnMapper implementation based on a map which contains all allowed values -// and maps them with a map +// and maps them with a map. type MapMapper struct { mapping map[string]string } -// NewMapMapper returns a MapMapper for a specific map +// NewMapMapper returns a MapMapper for a specific map. func NewMapMapper(mapping map[string]string) *MapMapper { return &MapMapper{mapping: mapping} } -// Map returns the mapped value and if it is valid based on a map +// Map returns the mapped value and if it is valid based on a map. func (m *MapMapper) Map(value string) (string, bool) { val, isValid := m.mapping[value] return val, isValid } -// UrlQueryParameters contains all information that is needed for pagination, sorting and filtering. -// It is not depending on orm.Query -type UrlQueryParameters struct { +// URLQueryParameters contains all information that is needed for pagination, sorting and filtering. +// It is not depending on orm.Query. +type URLQueryParameters struct { HasPagination bool PageNr int PageSize int @@ -73,36 +72,46 @@ type UrlQueryParameters struct { // ReadURLQueryParameters reads sorting, filter and pagination from requests and return a UrlQueryParameters object, // even if any errors occur. The returned error combines all errors of pagination, filter and sorting. -func ReadURLQueryParameters(r *http.Request, mapper ColumnMapper, sanitizer ValueSanitizer) (*UrlQueryParameters, error) { - result := &UrlQueryParameters{} +func ReadURLQueryParameters(r *http.Request, mapper ColumnMapper, sanitizer ValueSanitizer) (*URLQueryParameters, error) { + result := &URLQueryParameters{} + var errs []error + if err := result.readPagination(r); err != nil { errs = append(errs, err) } + if err := result.readSorting(r, mapper); err != nil { errs = append(errs, err) } + if err := result.readFilter(r, mapper, sanitizer); err != nil { errs = append(errs, err) } + if len(errs) == 0 { return result, nil } + if len(errs) == 1 { return result, errs[0] } - var errAggregate []string - for _, err := range errs { - errAggregate = append(errAggregate, err.Error()) + + errAggregate := make([]string, len(errs)) + + for i, err := range errs { + errAggregate[i] = err.Error() } + return result, fmt.Errorf("reading URL Query Parameters cased multiple errors: %v", strings.Join(errAggregate, ",")) } -// AddToQuery adds filter, sorting and pagination to a orm.Query -func (u *UrlQueryParameters) AddToQuery(query *orm.Query) *orm.Query { +// AddToQuery adds filter, sorting and pagination to a orm.Query. +func (u *URLQueryParameters) AddToQuery(query *orm.Query) *orm.Query { if u.HasPagination { query.Offset(u.PageSize * u.PageNr).Limit(u.PageSize) } + for name, filterValues := range u.Filter { if len(filterValues) == 0 { continue @@ -112,26 +121,33 @@ func (u *UrlQueryParameters) AddToQuery(query *orm.Query) *orm.Query { query.Where(name+" = ?", filterValues[0]) continue } + query.Where(name+" IN (?)", pg.In(filterValues)) } + for _, val := range u.Order { query.Order(val) } + return query } -func (u *UrlQueryParameters) readPagination(r *http.Request) error { +func (u *URLQueryParameters) readPagination(r *http.Request) error { pageStr := r.URL.Query().Get("page[number]") sizeStr := r.URL.Query().Get("page[size]") + if pageStr == "" { u.HasPagination = false return nil } + u.HasPagination = true + pageNr, err := strconv.Atoi(pageStr) if err != nil { return err } + var pageSize int if sizeStr != "" { pageSize, err = strconv.Atoi(sizeStr) @@ -141,32 +157,40 @@ func (u *UrlQueryParameters) readPagination(r *http.Request) error { } else { pageSize = cfg.DefaultPageSize } + if (pageSize < cfg.MinPageSize) || (pageSize > cfg.MaxPageSize) { return fmt.Errorf("invalid pagesize not between min. and max. value, min: %d, max: %d", cfg.MinPageSize, cfg.MaxPageSize) } + u.PageNr = pageNr u.PageSize = pageSize + return nil } -func (u *UrlQueryParameters) readSorting(r *http.Request, mapper ColumnMapper) error { +func (u *URLQueryParameters) readSorting(r *http.Request, mapper ColumnMapper) error { sort := r.URL.Query().Get("sort") if sort == "" { return nil } + sorting := strings.Split(sort, ",") var order string - var resultedOrders []string - var errSortingWithReason []string + + resultedOrders := make([]string, 0) + errSortingWithReason := make([]string, 0) + for _, val := range sorting { if val == "" { continue } + order = " ASC" if strings.HasPrefix(val, "-") { order = " DESC" } + val = strings.TrimPrefix(val, "-") key, isValid := mapper.Map(val) @@ -174,38 +198,50 @@ func (u *UrlQueryParameters) readSorting(r *http.Request, mapper ColumnMapper) e errSortingWithReason = append(errSortingWithReason, val) continue } + resultedOrders = append(resultedOrders, key+order) } + u.Order = resultedOrders + if len(errSortingWithReason) > 0 { return fmt.Errorf("at least one sorting parameter is not valid: %q", strings.Join(errSortingWithReason, ",")) } + return nil } -func (u *UrlQueryParameters) readFilter(r *http.Request, mapper ColumnMapper, sanitizer ValueSanitizer) error { +func (u *URLQueryParameters) readFilter(r *http.Request, mapper ColumnMapper, sanitizer ValueSanitizer) error { filter := make(map[string][]interface{}) + var invalidFilter []string + for queryName, queryValues := range r.URL.Query() { if !(strings.HasPrefix(queryName, "filter[") && strings.HasSuffix(queryName, "]")) { continue } + key, isValid := getFilterKey(queryName, mapper) if !isValid { invalidFilter = append(invalidFilter, key) continue } + filterValues, isValid := getFilterValues(key, queryValues, sanitizer) if !isValid { invalidFilter = append(invalidFilter, key) continue } + filter[key] = filterValues } + u.Filter = filter + if len(invalidFilter) != 0 { return fmt.Errorf("at least one filter parameter is not valid: %q", strings.Join(invalidFilter, ",")) } + return nil } @@ -213,14 +249,17 @@ func getFilterKey(queryName string, modelMapping ColumnMapper) (string, bool) { field := strings.TrimPrefix(queryName, "filter[") field = strings.TrimSuffix(field, "]") mapped, isValid := modelMapping.Map(field) + if !isValid { return field, false } + return mapped, true } func getFilterValues(fieldName string, queryValues []string, sanitizer ValueSanitizer) ([]interface{}, bool) { var filterValues []interface{} + for _, value := range queryValues { separatedValues := strings.Split(value, ",") for _, separatedValue := range separatedValues { @@ -228,8 +267,10 @@ func getFilterValues(fieldName string, queryValues []string, sanitizer ValueSani if err != nil { return nil, false } + filterValues = append(filterValues, sanitized) } } + return filterValues, true } diff --git a/http/jsonapi/runtime/standard_params_test.go b/http/jsonapi/runtime/standard_params_test.go index 72be89027..22a23fa0f 100644 --- a/http/jsonapi/runtime/standard_params_test.go +++ b/http/jsonapi/runtime/standard_params_test.go @@ -4,6 +4,7 @@ package runtime_test import ( "context" + "net/http" "net/http/httptest" "sort" "testing" @@ -34,6 +35,7 @@ func TestIntegrationFilterParameter(t *testing.T) { // Setup a := assert.New(t) db := setupDatabase(a) + defer func() { // Tear Down err := db.DropTable(&TestModel{}, &orm.DropTableOptions{}) @@ -45,20 +47,24 @@ func TestIntegrationFilterParameter(t *testing.T) { } mapper := runtime.NewMapMapper(mappingNames) // filter - r := httptest.NewRequest("GET", "http://abc.de/whatEver?filter[test]=b", nil) + r := httptest.NewRequest(http.MethodGet, "http://abc.de/whatEver?filter[test]=b", nil) urlParams, err := runtime.ReadURLQueryParameters(r, mapper, &testValueSanitizer{}) a.NoError(err) + var modelsFilter []TestModel + q := db.Model(&modelsFilter) q = urlParams.AddToQuery(q) count, _ := q.SelectAndCount() a.Equal(1, count) a.Equal("b", modelsFilter[0].FilterName) - r = httptest.NewRequest("GET", "http://abc.de/whatEver?filter[test]=a,b", nil) + r = httptest.NewRequest(http.MethodGet, "http://abc.de/whatEver?filter[test]=a,b", nil) urlParams, err = runtime.ReadURLQueryParameters(r, mapper, &testValueSanitizer{}) a.NoError(err) + var modelsFilter2 []TestModel + q = db.Model(&modelsFilter2) q = urlParams.AddToQuery(q) count, _ = q.SelectAndCount() @@ -70,10 +76,12 @@ func TestIntegrationFilterParameter(t *testing.T) { a.Equal("b", modelsFilter2[1].FilterName) // Paging - r = httptest.NewRequest("GET", "http://abc.de/whatEver?page[number]=1&page[size]=2", nil) + r = httptest.NewRequest(http.MethodGet, "http://abc.de/whatEver?page[number]=1&page[size]=2", nil) urlParams, err = runtime.ReadURLQueryParameters(r, mapper, &testValueSanitizer{}) assert.NoError(t, err) + var modelsPaging []TestModel + q = db.Model(&modelsPaging) q = urlParams.AddToQuery(q) err = q.Select() @@ -85,10 +93,12 @@ func TestIntegrationFilterParameter(t *testing.T) { a.Equal("d", modelsPaging[1].FilterName) // Sorting - r = httptest.NewRequest("GET", "http://abc.de/whatEver?sort=-test", nil) + r = httptest.NewRequest(http.MethodGet, "http://abc.de/whatEver?sort=-test", nil) urlParams, err = runtime.ReadURLQueryParameters(r, mapper, &testValueSanitizer{}) assert.NoError(t, err) + var modelsSort []TestModel + q = db.Model(&modelsSort) q = urlParams.AddToQuery(q) err = q.Select() @@ -102,10 +112,12 @@ func TestIntegrationFilterParameter(t *testing.T) { a.Equal("a", modelsSort[5].FilterName) // Combine all - r = httptest.NewRequest("GET", "http://abc.de/whatEver?sort=-test&filter[test]=a,b,e,f&page[number]=1&page[size]=2", nil) + r = httptest.NewRequest(http.MethodGet, "http://abc.de/whatEver?sort=-test&filter[test]=a,b,e,f&page[number]=1&page[size]=2", nil) urlParams, err = runtime.ReadURLQueryParameters(r, mapper, &testValueSanitizer{}) assert.NoError(t, err) + var modelsCombined []TestModel + q = db.Model(&modelsCombined) q = urlParams.AddToQuery(q) err = q.Select() @@ -121,6 +133,7 @@ func setupDatabase(a *assert.Assertions) *pg.DB { err := db.CreateTable(&TestModel{}, &orm.CreateTableOptions{}) a.NoError(err) + _, err = db.Model(&TestModel{ FilterName: "a", }).Insert() diff --git a/http/jsonapi/runtime/validation.go b/http/jsonapi/runtime/validation.go index 86e4da00e..5195d7d1a 100644 --- a/http/jsonapi/runtime/validation.go +++ b/http/jsonapi/runtime/validation.go @@ -3,12 +3,14 @@ package runtime import ( + "errors" "fmt" "net/http" "strings" "time" valid "github.com/asaskevich/govalidator" + "github.com/pace/bricks/pkg/isotime" ) @@ -28,14 +30,14 @@ func init() { // ValidateParameters checks the given struct and returns true if the struct // is valid according to the specification (declared with go-validator struct tags) -// In case of an error, an jsonapi error message will be directly send to the client +// In case of an error, an jsonapi error message will be directly send to the client. func ValidateParameters(w http.ResponseWriter, r *http.Request, data interface{}) bool { return ValidateStruct(w, r, data, "parameter") } // ValidateRequest checks the given struct and returns true if the struct // is valid according to the specification (declared with go-validator struct tags) -// In case of an error, an jsonapi error message will be directly send to the client +// In case of an error, an jsonapi error message will be directly send to the client. func ValidateRequest(w http.ResponseWriter, r *http.Request, data interface{}) bool { return ValidateStruct(w, r, data, "pointer") } @@ -43,20 +45,19 @@ func ValidateRequest(w http.ResponseWriter, r *http.Request, data interface{}) b // ValidateStruct checks the given struct and returns true if the struct // is valid according to the specification (declared with go-validator struct tags) // In case of an error, an jsonapi error message will be directly send to the client -// The passed source is the source for validation errors (e.g. pointer for data or parameter) +// The passed source is the source for validation errors (e.g. pointer for data or parameter). func ValidateStruct(w http.ResponseWriter, r *http.Request, data interface{}, source string) bool { ok, err := valid.ValidateStruct(data) - if !ok { - switch errs := err.(type) { - case valid.Errors: + validErrors := valid.Errors{} + + if errors.As(err, &validErrors) { var e Errors - generateValidationErrors(errs, &e, source) + + generateValidationErrors(validErrors, &e, source) WriteError(w, http.StatusUnprocessableEntity, e) - case error: - panic(err) // programming error, e.g. not used with struct - default: - panic(fmt.Errorf("unhandled error case: %s", err)) + } else { + panic(fmt.Errorf("unhandled error case: %w", err)) } return false @@ -65,16 +66,21 @@ func ValidateStruct(w http.ResponseWriter, r *http.Request, data interface{}, so return true } -// convert govalidator errors into jsonapi errors +// convert govalidator errors into jsonapi errors. func generateValidationErrors(validErrors valid.Errors, jsonapiErrors *Errors, source string) { for _, err := range validErrors { - switch e := err.(type) { - case valid.Errors: - generateValidationErrors(e, jsonapiErrors, source) - case valid.Error: - *jsonapiErrors = append(*jsonapiErrors, generateValidationError(e, source)) - default: - panic(fmt.Errorf("unhandled error case: %s", e)) + validErrors := valid.Errors{} + + if errors.As(err, &validErrors) { + generateValidationErrors(validErrors, jsonapiErrors, source) + } else { + validError := valid.Error{} + + if errors.As(err, &validError) { + *jsonapiErrors = append(*jsonapiErrors, generateValidationError(validError, source)) + } else { + panic(fmt.Errorf("unhandled error case: %w", err)) + } } } } @@ -88,7 +94,7 @@ func generateValidationErrors(validErrors valid.Errors, jsonapiErrors *Errors, s // https://github.com/pace/bricks/issues/10 // generateValidationError generates a new jsonapi error based -// on the given govalidator error +// on the given govalidator error. func generateValidationError(e valid.Error, source string) *Error { path := "" for _, p := range append(e.Path, e.Name) { diff --git a/http/jsonapi/runtime/validation_test.go b/http/jsonapi/runtime/validation_test.go index cb7c4e8eb..f3bd461ac 100644 --- a/http/jsonapi/runtime/validation_test.go +++ b/http/jsonapi/runtime/validation_test.go @@ -17,10 +17,12 @@ func TestValidateParametersWithError(t *testing.T) { type access struct { Token string `valid:"uuid"` } + type input struct { UUID string `valid:"uuid"` Access access } + expected := map[string]interface{}{ "errors": []interface{}{ map[string]interface{}{ @@ -49,24 +51,27 @@ func TestValidateParametersWithError(t *testing.T) { } rec := httptest.NewRecorder() - req := httptest.NewRequest("POST", "/", nil) + req := httptest.NewRequest(http.MethodPost, "/", nil) ok := ValidateParameters(rec, req, &val) - if ok { t.Error("expected to fail the validation") } resp := rec.Result() - defer resp.Body.Close() - if resp.StatusCode != 422 { + defer func() { + err := resp.Body.Close() + assert.NoError(t, err) + }() + + if resp.StatusCode != http.StatusUnprocessableEntity { t.Error("expected UnprocessableEntity") } var data map[string]interface{} - err := json.NewDecoder(resp.Body).Decode(&data) - if err != nil { + + if err := json.NewDecoder(resp.Body).Decode(&data); err != nil { t.Fatal(err) } @@ -77,13 +82,14 @@ func TestValidateParametersWithError(t *testing.T) { func TestValidateRequest(t *testing.T) { rec := httptest.NewRecorder() - req := httptest.NewRequest("POST", "/", nil) + req := httptest.NewRequest(http.MethodPost, "/", nil) type args struct { w http.ResponseWriter r *http.Request data interface{} } + tests := []struct { name string args args diff --git a/http/jsonapi/runtime/value_sanitizers.go b/http/jsonapi/runtime/value_sanitizers.go index 38b6c0ccc..9924ed531 100644 --- a/http/jsonapi/runtime/value_sanitizers.go +++ b/http/jsonapi/runtime/value_sanitizers.go @@ -31,6 +31,7 @@ func (d datetimeSanitizer) SanitizeValue(fieldName string, value string) (interf if err != nil { return nil, err } + return t, nil } @@ -58,6 +59,7 @@ func (u uuidSanitizer) SanitizeValue(fieldName string, value string) (interface{ if _, err := uuid.Parse(value); err != nil { return nil, err } + return value, nil } @@ -68,6 +70,7 @@ func (c composableAndFieldRestrictedSanitizer) SanitizeValue(fieldName string, v if !found { return nil, fmt.Errorf("%w: %v", ErrInvalidFieldname, fieldName) } + return san.SanitizeValue(fieldName, value) } diff --git a/http/longpoll/longpoll.go b/http/longpoll/longpoll.go index 782b9721c..ee9478ffa 100644 --- a/http/longpoll/longpoll.go +++ b/http/longpoll/longpoll.go @@ -12,7 +12,7 @@ import ( // longpolling request. type LongPollFunc func(context.Context) (bool, error) -// Config for long polling +// Config for long polling. type Config struct { // RetryTime time to wait between two retries RetryTime time.Duration @@ -23,7 +23,7 @@ type Config struct { } // Default configuration for http long polling -// wait half a second between retries, min 1 sec and max 60 sec +// wait half a second between retries, min 1 sec and max 60 sec. var Default = Config{ RetryTime: time.Millisecond * 500, MinWaitTime: time.Second, @@ -31,7 +31,7 @@ var Default = Config{ } // Until executes the given function fn until duration d is passed or context is canceled. -// The constaints of the Default configuration apply. +// The constraints of the Default configuration apply. func Until(ctx context.Context, d time.Duration, fn LongPollFunc) (ok bool, err error) { return Default.LongPollUntil(ctx, d, fn) } @@ -41,7 +41,7 @@ func Until(ctx context.Context, d time.Duration, fn LongPollFunc) (ok bool, err // be set to the allowed min/max respectively. Other checking is up to the caller. The resulting time // budget is communicated via the provided context. This is a defence measure to not have accidental // long running routines. If no duration is given (0) the long poll will have exactly one execution. -func (c Config) LongPollUntil(ctx context.Context, d time.Duration, fn LongPollFunc) (ok bool, err error) { +func (c Config) LongPollUntil(ctx context.Context, d time.Duration, fn LongPollFunc) (bool, error) { until := time.Now() if d != 0 { @@ -57,11 +57,16 @@ func (c Config) LongPollUntil(ctx context.Context, d time.Duration, fn LongPollF fnCtx, cancel := context.WithDeadline(ctx, until) defer cancel() + var ( + ok bool + err error + ) + loop: for { ok, err = fn(fnCtx) if err != nil { - return + break } // fn returns true, break the loop @@ -88,5 +93,5 @@ loop: } } - return + return ok, err } diff --git a/http/longpoll/longpoll_test.go b/http/longpoll/longpoll_test.go index 289860ab9..da07b5fde 100644 --- a/http/longpoll/longpoll_test.go +++ b/http/longpoll/longpoll_test.go @@ -14,8 +14,10 @@ func TestLongPollUntilBounds(t *testing.T) { ok, err := Until(context.Background(), -1, func(ctx context.Context) (bool, error) { budget, ok := ctx.Deadline() assert.True(t, ok) - assert.Equal(t, time.Millisecond*999, budget.Sub(time.Now()).Truncate(time.Millisecond)) // nolint: gosimple + assert.Equal(t, time.Millisecond*999, time.Until(budget).Truncate(time.Millisecond)) + called++ + return true, nil }) assert.True(t, ok) @@ -26,8 +28,10 @@ func TestLongPollUntilBounds(t *testing.T) { ok, err = Until(context.Background(), time.Hour, func(ctx context.Context) (bool, error) { budget, ok := ctx.Deadline() assert.True(t, ok) - assert.Equal(t, time.Second*59, budget.Sub(time.Now()).Truncate(time.Second)) // nolint: gosimple + assert.Equal(t, time.Second*59, time.Until(budget).Truncate(time.Second)) + called++ + return true, nil }) assert.True(t, ok) @@ -79,8 +83,10 @@ func TestLongPollUntilTimeout(t *testing.T) { func TestLongPollUntilTimeoutWithContext(t *testing.T) { called := 0 + ctx, cancel := context.WithTimeout(context.Background(), time.Second) defer cancel() + ok, err := Until(ctx, time.Second*2, func(context.Context) (bool, error) { called++ return false, nil diff --git a/http/middleware/context.go b/http/middleware/context.go index 35bfc66f9..6a0eaab66 100644 --- a/http/middleware/context.go +++ b/http/middleware/context.go @@ -29,6 +29,7 @@ func ContextTransfer(ctx, targetCtx context.Context) context.Context { if r := requestFromContext(ctx); r != nil { return contextWithRequest(targetCtx, r) } + return targetCtx } @@ -44,8 +45,11 @@ func contextWithRequest(ctx context.Context, ctxReq *ctxRequest) context.Context func requestFromContext(ctx context.Context) *ctxRequest { if v := ctx.Value((*ctxRequest)(nil)); v != nil { - return v.(*ctxRequest) + if request, ok := v.(*ctxRequest); ok { + return request + } } + return nil } @@ -67,21 +71,26 @@ func GetXForwardedForHeaderFromContext(ctx context.Context) (string, error) { if ctxReq == nil { return "", fmt.Errorf("getting request from context: %w", ErrNotFound) } + xForwardedFor := ctxReq.XForwardedFor + ip, _, err := net.SplitHostPort(ctxReq.RemoteAddr) if err != nil { return "", fmt.Errorf( - "%w (from context): could not get ip from remote address: %s", + "%w (from context): could not get ip from remote address: %w", ErrInvalidRequest, err) } + if ip == "" { return "", fmt.Errorf( "%w (from context): could not get ip from remote address: %q", ErrInvalidRequest, ctxReq.RemoteAddr) } + if xForwardedFor != "" { xForwardedFor += ", " } + return xForwardedFor + ip, nil } @@ -93,5 +102,6 @@ func GetUserAgentFromContext(ctx context.Context) (string, error) { if ctxReq == nil { return "", fmt.Errorf("getting request from context: %w", ErrNotFound) } + return ctxReq.UserAgent, nil } diff --git a/http/middleware/context_test.go b/http/middleware/context_test.go index 2f12aee09..026697166 100644 --- a/http/middleware/context_test.go +++ b/http/middleware/context_test.go @@ -8,13 +8,14 @@ import ( "net/http" "testing" - . "github.com/pace/bricks/http/middleware" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + + . "github.com/pace/bricks/http/middleware" ) func TestContextTransfer(t *testing.T) { - r, err := http.NewRequest("GET", "http://example.com/", nil) + r, err := http.NewRequest(http.MethodGet, "http://example.com/", nil) require.NoError(t, err) r.Header.Set("User-Agent", "Foobar") RequestInContext(http.HandlerFunc(func(_ http.ResponseWriter, r *http.Request) { @@ -72,12 +73,14 @@ func TestGetXForwardedForHeaderFromContext(t *testing.T) { } for name, c := range cases { t.Run(name, func(t *testing.T) { - r, err := http.NewRequest("GET", "http://example.com/", nil) + r, err := http.NewRequest(http.MethodGet, "http://example.com/", nil) require.NoError(t, err) + r.RemoteAddr = c.RemoteAddr if c.XForwardedFor != "" { r.Header.Set("X-Forwarded-For", c.XForwardedFor) } + RequestInContext(http.HandlerFunc(func(_ http.ResponseWriter, r *http.Request) { ctx := r.Context() xForwardedFor, err := GetXForwardedForHeaderFromContext(ctx) @@ -98,7 +101,7 @@ func TestGetXForwardedForHeaderFromContext(t *testing.T) { } func TestGetUserAgentFromContext(t *testing.T) { - r, err := http.NewRequest("GET", "http://example.com/", nil) + r, err := http.NewRequest(http.MethodGet, "http://example.com/", nil) require.NoError(t, err) r.Header.Set("User-Agent", "Foobar") RequestInContext(http.HandlerFunc(func(_ http.ResponseWriter, r *http.Request) { diff --git a/http/middleware/external_dependency.go b/http/middleware/external_dependency.go index 382a54b53..3886db804 100644 --- a/http/middleware/external_dependency.go +++ b/http/middleware/external_dependency.go @@ -14,16 +14,17 @@ import ( "github.com/pace/bricks/maintenance/log" ) -// depFormat is the format of a single dependency report +// depFormat is the format of a single dependency report. const depFormat = "%s:%d" -// ExternalDependencyHeaderName name of the HTTP header that is used for reporting +// ExternalDependencyHeaderName name of the HTTP header that is used for reporting. const ExternalDependencyHeaderName = "External-Dependencies" -// ExternalDependency middleware to report external dependencies +// ExternalDependency middleware to report external dependencies. func ExternalDependency(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { var edc ExternalDependencyContext + edw := externalDependencyWriter{ ResponseWriter: w, edc: &edc, @@ -39,6 +40,7 @@ func AddExternalDependency(ctx context.Context, name string, dur time.Duration) log.Ctx(ctx).Warn().Msgf("can't add external dependency %q with %s, because context is missing", name, dur) return } + ec.AddDependency(name, dur) } @@ -48,12 +50,13 @@ type externalDependencyWriter struct { edc *ExternalDependencyContext } -// addHeader adds the external dependency header if not done already +// addHeader adds the external dependency header if not done already. func (w *externalDependencyWriter) addHeader() { if !w.header { if len(w.edc.dependencies) > 0 { w.ResponseWriter.Header().Add(ExternalDependencyHeaderName, w.edc.String()) } + w.header = true } } @@ -68,21 +71,25 @@ func (w *externalDependencyWriter) WriteHeader(statusCode int) { w.ResponseWriter.WriteHeader(statusCode) } -// ContextWithExternalDependency creates a contex with the external provided dependencies +// ContextWithExternalDependency creates a contex with the external provided dependencies. func ContextWithExternalDependency(ctx context.Context, edc *ExternalDependencyContext) context.Context { return context.WithValue(ctx, (*ExternalDependencyContext)(nil), edc) } -// ExternalDependencyContextFromContext returns the external dependencies context or nil +// ExternalDependencyContextFromContext returns the external dependencies context or nil. func ExternalDependencyContextFromContext(ctx context.Context) *ExternalDependencyContext { if v := ctx.Value((*ExternalDependencyContext)(nil)); v != nil { - return v.(*ExternalDependencyContext) + out, ok := v.(*ExternalDependencyContext) + if ok { + return out + } } + return nil } // ExternalDependencyContext contains all dependencies that were seen -// during the request livecycle +// during the request livecycle. type ExternalDependencyContext struct { mu sync.RWMutex dependencies []externalDependency @@ -97,21 +104,26 @@ func (c *ExternalDependencyContext) AddDependency(name string, duration time.Dur c.mu.Unlock() } -// String formats all external dependencies +// String formats all external dependencies. func (c *ExternalDependencyContext) String() string { var b strings.Builder + sep := len(c.dependencies) - 1 + for _, dep := range c.dependencies { b.WriteString(dep.String()) + if sep > 0 { b.WriteByte(',') + sep-- } } + return b.String() } -// Parse a external dependency value +// Parse a external dependency value. func (c *ExternalDependencyContext) Parse(s string) { values := strings.Split(s, ",") for _, value := range values { @@ -119,6 +131,7 @@ func (c *ExternalDependencyContext) Parse(s string) { if index == -1 { continue // ignore the invalid values } + dur, err := strconv.ParseInt(value[index+1:], 10, 64) if err != nil { continue // ignore the invalid values @@ -129,13 +142,13 @@ func (c *ExternalDependencyContext) Parse(s string) { } // externalDependency represents one external dependency that -// was involved in the process to creating a response +// was involved in the process to creating a response. type externalDependency struct { Name string // canonical name of the source Duration time.Duration // time spend with the external dependency } -// String returns a formated single external dependency +// String returns a formated single external dependency. func (r externalDependency) String() string { return fmt.Sprintf(depFormat, r.Name, r.Duration.Milliseconds()) } diff --git a/http/middleware/external_dependency_test.go b/http/middleware/external_dependency_test.go index 77b668f51..74f357cc8 100644 --- a/http/middleware/external_dependency_test.go +++ b/http/middleware/external_dependency_test.go @@ -10,6 +10,7 @@ import ( "time" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) func Test_ExternalDependency_Middleare(t *testing.T) { @@ -22,7 +23,15 @@ func Test_ExternalDependency_Middleare(t *testing.T) { w.WriteHeader(http.StatusOK) })) h.ServeHTTP(rec, req) - assert.Nil(t, rec.Result().Header[ExternalDependencyHeaderName]) + + res := rec.Result() + + defer func() { + err := res.Body.Close() + assert.NoError(t, err) + }() + + assert.Nil(t, res.Header[ExternalDependencyHeaderName]) }) t.Run("one dependency set", func(t *testing.T) { rec := httptest.NewRecorder() @@ -30,10 +39,20 @@ func Test_ExternalDependency_Middleare(t *testing.T) { h := ExternalDependency(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { AddExternalDependency(r.Context(), "test", time.Second) - w.Write(nil) // nolint: errcheck + + _, err := w.Write(nil) + require.NoError(t, err) })) h.ServeHTTP(rec, req) - assert.Equal(t, rec.Result().Header[ExternalDependencyHeaderName][0], "test:1000") + + res := rec.Result() + + defer func() { + err := res.Body.Close() + assert.NoError(t, err) + }() + + assert.Equal(t, res.Header[ExternalDependencyHeaderName][0], "test:1000") }) } diff --git a/http/middleware/metrics.go b/http/middleware/metrics.go index 77b21166d..88d71f6f7 100644 --- a/http/middleware/metrics.go +++ b/http/middleware/metrics.go @@ -63,9 +63,11 @@ func Metrics(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { paceHTTPInFlightGauge.Inc() defer paceHTTPInFlightGauge.Dec() + startTime := time.Now() srw := statusWriter{ResponseWriter: w} next.ServeHTTP(&srw, r) + dur := float64(time.Since(startTime)) / float64(time.Millisecond) labels := prometheus.Labels{ "code": strconv.Itoa(srw.status), @@ -91,10 +93,12 @@ func (w *statusWriter) WriteHeader(status int) { func (w *statusWriter) Write(b []byte) (int, error) { if w.status == 0 { - w.status = 200 + w.status = http.StatusOK } + n, err := w.ResponseWriter.Write(b) w.length += n + return n, err } @@ -103,5 +107,6 @@ func filterRequestSource(source string) string { case "uptime", "kubernetes", "nginx", "livetest": return source } + return "" } diff --git a/http/middleware/response_header.go b/http/middleware/response_header.go index 59a6ad158..b3457da3a 100644 --- a/http/middleware/response_header.go +++ b/http/middleware/response_header.go @@ -10,7 +10,7 @@ import ( jwt "github.com/golang-jwt/jwt/v5" ) -// ClientIDHeaderName name of the HTTP header that is used for reporting +// ClientIDHeaderName name of the HTTP header that is used for reporting. const ( ClientIDHeaderName = "Client-ID" ) @@ -28,6 +28,7 @@ func ClientID(next http.Handler) http.Handler { w.Header().Add(ClientIDHeaderName, claim.AuthorizedParty) } } + next.ServeHTTP(w, r) }) } @@ -41,5 +42,6 @@ func (c clientIDClaim) Valid() error { if c.AuthorizedParty == "" { return ErrEmptyAuthorizedParty } + return nil } diff --git a/http/middleware/response_header_test.go b/http/middleware/response_header_test.go index c5e897f9c..cf165a6dd 100644 --- a/http/middleware/response_header_test.go +++ b/http/middleware/response_header_test.go @@ -12,8 +12,8 @@ import ( ) const ( - emptyToken = "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiaWF0IjoxNTE2MjM5MDIyfQ.SflKxwRJSMeKKF2QT4fwpMeJf36POk6yJV_adQssw5c" - token = "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiaWF0IjoxNTE2MjM5MDIyLCJhenAiOiJjbGllbnRUZXN0In0.eAUlRLw2R2LEvI9TdaD9P6zGQyz-oF7V-Omm2x00iQk" + emptyToken = "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiaWF0IjoxNTE2MjM5MDIyfQ.SflKxwRJSMeKKF2QT4fwpMeJf36POk6yJV_adQssw5c" //nolint:gosec + token = "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiaWF0IjoxNTE2MjM5MDIyLCJhenAiOiJjbGllbnRUZXN0In0.eAUlRLw2R2LEvI9TdaD9P6zGQyz-oF7V-Omm2x00iQk" //nolint:gosec ) func TestClientID(t *testing.T) { diff --git a/http/oauth2/authorizer.go b/http/oauth2/authorizer.go index d7058eeba..9b0267818 100644 --- a/http/oauth2/authorizer.go +++ b/http/oauth2/authorizer.go @@ -11,14 +11,14 @@ import ( ) // Authorizer is an implementation of security.Authorizer for OAuth2 -// it uses introspection to get user data and can check the scope +// it uses introspection to get user data and can check the scope. type Authorizer struct { introspection TokenIntrospecter scope Scope config *Config } -// Flow is a part of the OAuth2 config from the security schema +// Flow is a part of the OAuth2 config from the security schema. type Flow struct { AuthorizationURL string TokenURL string @@ -26,7 +26,7 @@ type Flow struct { Scopes map[string]string } -// Config contains the configuration from the api definition - currently not used +// Config contains the configuration from the api definition - currently not used. type Config struct { Description string Implicit *Flow @@ -36,22 +36,21 @@ type Config struct { } // NewAuthorizer creates an Authorizer for a specific TokenIntrospecter -// This Authorizer does not check the scope +// This Authorizer does not check the scope. func NewAuthorizer(introspector TokenIntrospecter, cfg *Config) *Authorizer { return &Authorizer{introspection: introspector, config: cfg} } -// WithScope returns a new Authorizer with the same TokenIntrospecter and the same Config that also checks the scope of a request +// WithScope returns a new Authorizer with the same TokenIntrospecter and the same Config that also checks the scope of a request. func (a *Authorizer) WithScope(tok string) *Authorizer { return &Authorizer{introspection: a.introspection, config: a.config, scope: Scope(tok)} } // Authorize authorizes a request with an introspection and validates the scope // Success: returns context with the introspection result and true -// Error: writes all errors directly to response, returns unchanged context and false +// Error: writes all errors directly to response, returns unchanged context and false. func (a *Authorizer) Authorize(r *http.Request, w http.ResponseWriter) (context.Context, bool) { ctx, ok := introspectRequest(r, w, a.introspection) - // Check if introspection was successful if !ok { return ctx, ok } @@ -60,6 +59,7 @@ func (a *Authorizer) Authorize(r *http.Request, w http.ResponseWriter) (context. // Check if the scope is valid for this user ok = validateScope(ctx, w, a.scope) } + return ctx, ok } @@ -68,10 +68,11 @@ func validateScope(ctx context.Context, w http.ResponseWriter, req Scope) bool { http.Error(w, fmt.Sprintf("Forbidden - requires scope %q", req), http.StatusForbidden) return false } + return true } -// CanAuthorizeRequest returns true, if the request contains a token in the configured header, otherwise false +// CanAuthorizeRequest returns true, if the request contains a token in the configured header, otherwise false. func (a *Authorizer) CanAuthorizeRequest(r *http.Request) bool { return security.GetBearerTokenFromHeader(r.Header.Get(oAuth2Header)) != "" } diff --git a/http/oauth2/example_multi_backend_test.go b/http/oauth2/example_multi_backend_test.go index 4637f7491..467574241 100644 --- a/http/oauth2/example_multi_backend_test.go +++ b/http/oauth2/example_multi_backend_test.go @@ -5,6 +5,7 @@ package oauth2_test import ( "context" "fmt" + "net/http" "net/http/httptest" "github.com/pace/bricks/http/oauth2" @@ -21,6 +22,7 @@ func (b multiAuthBackends) IntrospectToken(ctx context.Context, token string) (r return } } + return nil, oauth2.ErrInvalidToken } @@ -33,6 +35,7 @@ func (b *authBackend) IntrospectToken(ctx context.Context, token string) (*oauth Backend: b, }, nil } + return nil, oauth2.ErrInvalidToken } @@ -42,20 +45,23 @@ func Example_multipleBackends() { // authorized the request. The actual value used for the backend depends on // your implementation: you can use constants or pointers, like in this // example. - authorizer := oauth2.NewAuthorizer(multiAuthBackends{ &authBackend{"A", "token-a"}, &authBackend{"B", "token-b"}, &authBackend{"C", "token-c"}, }, nil) - r := httptest.NewRequest("GET", "/some/endpoint", nil) + r := httptest.NewRequest(http.MethodGet, "/some/endpoint", nil) r.Header.Set("Authorization", "Bearer token-b") if authorizer.CanAuthorizeRequest(r) { - ctx, ok := authorizer.Authorize(r, nil) + ctx, _ := authorizer.Authorize(r, nil) usedBackend, _ := oauth2.Backend(ctx) - fmt.Printf("%t %s", ok, usedBackend.(*authBackend)[0]) + + assertedBackend, ok := usedBackend.(*authBackend) + if ok { + fmt.Printf("%t %s", ok, assertedBackend[0]) + } } // Output: diff --git a/http/oauth2/introspection.go b/http/oauth2/introspection.go index 359364ed5..43dcd6555 100644 --- a/http/oauth2/introspection.go +++ b/http/oauth2/introspection.go @@ -7,22 +7,22 @@ import ( "errors" ) -// TokenIntrospecter needs to be implemented for token lookup +// TokenIntrospecter needs to be implemented for token lookup. type TokenIntrospecter interface { IntrospectToken(ctx context.Context, token string) (*IntrospectResponse, error) } -// ErrInvalidToken in case the token is not valid or expired +// ErrInvalidToken in case the token is not valid or expired. var ErrInvalidToken = errors.New("user token is invalid") -// ErrUpstreamConnection connection issue +// ErrUpstreamConnection connection issue. var ErrUpstreamConnection = errors.New("problem connecting to the introspection endpoint") -// ErrBadUpstreamResponse the response from the server has the wrong format +// ErrBadUpstreamResponse the response from the server has the wrong format. var ErrBadUpstreamResponse = errors.New("bad upstream response when introspecting token") // IntrospectResponse in case of a successful check of the -// oauth2 request +// oauth2 request. type IntrospectResponse struct { Active bool `json:"active"` Scope string `json:"scope"` diff --git a/http/oauth2/middleware/scopes_middleware.go b/http/oauth2/middleware/scopes_middleware.go index e612776fc..d8549e5d7 100644 --- a/http/oauth2/middleware/scopes_middleware.go +++ b/http/oauth2/middleware/scopes_middleware.go @@ -7,20 +7,21 @@ import ( "net/http" "github.com/gorilla/mux" + "github.com/pace/bricks/http/oauth2" ) -// RequiredScopes defines the scope each endpoint requires +// RequiredScopes defines the scope each endpoint requires. type RequiredScopes map[string]oauth2.Scope -// Deprecated: ScopesMiddleware contains required scopes for each endpoint - For generated APIs use the generated -// AuthenticationBackend with oauth2.Authorizer and set a Scope +// ScopesMiddleware contains required scopes for each endpoint. +// Deprecated: For generated APIs use the generated AuthenticationBackend with oauth2.Authorizer and set a Scope. type ScopesMiddleware struct { RequiredScopes RequiredScopes } -// Deprecated: NewScopesMiddleware return a new scopes middleware - For generated APIs use the generated -// AuthenticationBackend with oauth2.Authorizer and set a scope +// NewScopesMiddleware return a new scopes middleware. +// Deprecated: For generated APIs use the generated AuthenticationBackend with oauth2.Authorizer and set a scope. func NewScopesMiddleware(scopes RequiredScopes) *ScopesMiddleware { return &ScopesMiddleware{RequiredScopes: scopes} } @@ -34,6 +35,7 @@ func (m *ScopesMiddleware) Handler(next http.Handler) http.Handler { next.ServeHTTP(w, r) return } + http.Error(w, fmt.Sprintf("Forbidden - requires scope %q", m.RequiredScopes[routeName]), http.StatusForbidden) }) } diff --git a/http/oauth2/middleware/scopes_middleware_test.go b/http/oauth2/middleware/scopes_middleware_test.go index 7089f53a4..615a339ef 100644 --- a/http/oauth2/middleware/scopes_middleware_test.go +++ b/http/oauth2/middleware/scopes_middleware_test.go @@ -11,6 +11,8 @@ import ( "testing" "github.com/gorilla/mux" + "github.com/stretchr/testify/assert" + "github.com/pace/bricks/http/oauth2" ) @@ -23,12 +25,17 @@ func TestScopesMiddleware(t *testing.T) { resp := w.Result() body, err := io.ReadAll(resp.Body) - defer resp.Body.Close() + + defer func() { + err := resp.Body.Close() + assert.NoError(t, err) + }() + if err != nil { t.Fatal(err) } - if got, ex := resp.StatusCode, 200; got != ex { + if got, ex := resp.StatusCode, http.StatusOK; got != ex { t.Errorf("Expected status code %d, got %d", ex, got) } @@ -45,7 +52,12 @@ func TestScopesMiddleware(t *testing.T) { resp := w.Result() body, err := io.ReadAll(resp.Body) - defer resp.Body.Close() + + defer func() { + err := resp.Body.Close() + assert.NoError(t, err) + }() + if err != nil { t.Fatal(err) } @@ -64,21 +76,23 @@ func setupRouter(requiredScope string, tokenScope string) *mux.Router { rs := RequiredScopes{ "GetFoo": oauth2.Scope(requiredScope), } - m := NewScopesMiddleware(rs) // nolint: staticcheck - om := oauth2.NewMiddleware(&tokenIntrospecter{returnedScope: tokenScope}) // nolint: staticcheck + m := NewScopesMiddleware(rs) + om := oauth2.NewMiddleware(&tokenIntrospecter{returnedScope: tokenScope}) //nolint:staticcheck r := mux.NewRouter() r.Use(om.Handler) r.Use(m.Handler) r.HandleFunc("/foo", func(w http.ResponseWriter, r *http.Request) { - fmt.Fprint(w, "Hello") + if _, err := fmt.Fprint(w, "Hello"); err != nil { + w.WriteHeader(http.StatusInternalServerError) + } }).Name("GetFoo") return r } func setupRequest() *http.Request { - req := httptest.NewRequest("GET", "/foo", nil) + req := httptest.NewRequest(http.MethodGet, "/foo", nil) req.Header.Set("Authorization", "Bearer some-token") return req diff --git a/http/oauth2/oauth2.go b/http/oauth2/oauth2.go index 68c0f3a18..63695a5da 100644 --- a/http/oauth2/oauth2.go +++ b/http/oauth2/oauth2.go @@ -16,14 +16,14 @@ import ( "github.com/pace/bricks/maintenance/log" ) -// Deprecated: Middleware holds data necessary for Oauth processing - Deprecated for generated apis, -// use the generated Authentication Backend of the API with oauth2.Authorizer +// Middleware holds data necessary for Oauth processing. +// Deprecated: for generated apis, use the generated Authentication Backend of the API with oauth2.Authorizer. type Middleware struct { Backend TokenIntrospecter } -// Deprecated: NewMiddleware creates a new Oauth middleware - Deprecated for generated apis, -// use the generated AuthenticationBackend of the API with oauth2.Authorizer +// NewMiddleware creates a new Oauth middleware. +// Deprecated: for generated apis, use the generated AuthenticationBackend of the API with oauth2.Authorizer. func NewMiddleware(backend TokenIntrospecter) *Middleware { return &Middleware{Backend: backend} } @@ -36,6 +36,7 @@ func (m *Middleware) Handler(next http.Handler) http.Handler { if !isOk { return } + next.ServeHTTP(w, r.WithContext(ctx)) }) } @@ -51,7 +52,7 @@ type token struct { const oAuth2Header = "Authorization" -// GetValue returns the oauth2 token of the current user +// GetValue returns the oauth2 token of the current user. func (t *token) GetValue() string { return t.value } @@ -60,7 +61,7 @@ func (t *token) GetValue() string { // Success: it returns a context containing the introspection result and true // if the introspection was successful // Error: The function writes the error in the Response and creates a log-message -// with more details and returns nil and false if any error occurs during the introspection +// with more details and returns nil and false if any error occurs during the introspection. func introspectRequest(r *http.Request, w http.ResponseWriter, tokenIntro TokenIntrospecter) (context.Context, bool) { // Setup tracing span := sentry.StartSpan(r.Context(), "function", sentry.WithDescription("introspectRequest")) @@ -85,9 +86,10 @@ func introspectRequest(r *http.Request, w http.ResponseWriter, tokenIntro TokenI http.Error(w, err.Error(), http.StatusUnauthorized) default: http.Error(w, err.Error(), http.StatusInternalServerError) - } + log.Req(r).Info().Msg(err.Error()) + return nil, false } @@ -115,10 +117,11 @@ func fromIntrospectResponse(s *IntrospectResponse, tokenValue string) token { } t.scope = Scope(s.Scope) + return t } -// Request adds Authorization token to r +// Request adds Authorization token to r. func Request(r *http.Request) *http.Request { tok, ok := security.GetTokenFromContext(r.Context()) if ok { @@ -132,60 +135,73 @@ func Request(r *http.Request) *http.Request { // the permissions represented by the provided scope are included in the valid scope. func HasScope(ctx context.Context, scope Scope) bool { tok, _ := security.GetTokenFromContext(ctx) + oauth2token, ok := tok.(*token) if !ok { return false } + return scope.IsIncludedIn(oauth2token.scope) } -// UserID returns the userID stored in ctx +// UserID returns the userID stored in ctx. func UserID(ctx context.Context) (string, bool) { tok, _ := security.GetTokenFromContext(ctx) + oauth2token, ok := tok.(*token) if !ok { return "", false } + return oauth2token.userID, true } -// AuthTime returns the auth time stored in ctx as unix timestamp +// AuthTime returns the auth time stored in ctx as unix timestamp. func AuthTime(ctx context.Context) (int64, bool) { tok, _ := security.GetTokenFromContext(ctx) + oauth2token, ok := tok.(*token) if !ok { return 0, false } + return oauth2token.authTime, true } -// Scopes returns the scopes stored in ctx +// Scopes returns the scopes stored in ctx. func Scopes(ctx context.Context) []string { tok, _ := security.GetTokenFromContext(ctx) + oauth2token, ok := tok.(*token) if !ok { return []string{} } + return oauth2token.scope.toSlice() } func AddScope(ctx context.Context, scope string) context.Context { tok, _ := security.GetTokenFromContext(ctx) + oauth2token, ok := tok.(*token) if !ok { return ctx } + oauth2token.scope = oauth2token.scope.Add(scope) + return security.ContextWithToken(ctx, oauth2token) } -// ClientID returns the clientID stored in ctx +// ClientID returns the clientID stored in ctx. func ClientID(ctx context.Context) (string, bool) { tok, _ := security.GetTokenFromContext(ctx) + oauth2token, ok := tok.(*token) if !ok { return "", false } + return oauth2token.clientID, true } @@ -193,15 +209,17 @@ func ClientID(ctx context.Context) (string, bool) { // authorization backend for the token. func Backend(ctx context.Context) (interface{}, bool) { tok, _ := security.GetTokenFromContext(ctx) + oauth2token, ok := tok.(*token) if !ok { return nil, false } + return oauth2token.backend, true } // ContextTransfer sources the oauth2 token from the sourceCtx -// and returning a new context based on the targetCtx +// and returning a new context based on the targetCtx. func ContextTransfer(sourceCtx context.Context, targetCtx context.Context) context.Context { tok, _ := security.GetTokenFromContext(sourceCtx) return security.ContextWithToken(targetCtx, tok) @@ -209,11 +227,12 @@ func ContextTransfer(sourceCtx context.Context, targetCtx context.Context) conte // Deprecated: BearerToken was moved to the security package, // because it's used by apiKey and oauth2 authorization. -// BearerToken returns the bearer token stored in ctx +// BearerToken returns the bearer token stored in ctx. func BearerToken(ctx context.Context) (string, bool) { if tok, ok := security.GetTokenFromContext(ctx); ok { return tok.GetValue(), true } + return "", false } diff --git a/http/oauth2/oauth2_test.go b/http/oauth2/oauth2_test.go index c20e09e19..0c3ac8d6d 100644 --- a/http/oauth2/oauth2_test.go +++ b/http/oauth2/oauth2_test.go @@ -10,8 +10,11 @@ import ( "net/http" "net/http/httptest" "testing" + "time" "github.com/gorilla/mux" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" "github.com/pace/bricks/http/security" "github.com/pace/bricks/maintenance/log" @@ -58,7 +61,7 @@ func TestHandlerIntrospectErrorAsMiddleware(t *testing.T) { r.Use(m.Handler) r.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) {}) - req := httptest.NewRequest("GET", "/", nil) + req := httptest.NewRequest(http.MethodGet, "/", nil) req.Header.Set("Authorization", "Bearer some-token") w := httptest.NewRecorder() @@ -66,7 +69,12 @@ func TestHandlerIntrospectErrorAsMiddleware(t *testing.T) { resp := w.Result() body, err := io.ReadAll(resp.Body) - resp.Body.Close() + + defer func() { + err := resp.Body.Close() + assert.NoError(t, err) + }() + if err != nil { t.Fatal(err) } @@ -96,64 +104,77 @@ func TestAuthenticatorWithSuccess(t *testing.T) { userScopes string expectedScopes string active bool - clientId string - userId string + clientID string + userID string }{ { desc: "Tests a valid Request with OAuth2 Authentication without Scope checking", active: true, userScopes: "ABC DHHG kjdk", - clientId: "ClientId", - userId: "UserId", + clientID: "ClientId", + userID: "UserId", }, { desc: "Tests a valid Request with OAuth2 Authentication and one scope to check", active: true, userScopes: "ABC DHHG kjdk", - clientId: "ClientId", - userId: "UserId", + clientID: "ClientId", + userID: "UserId", expectedScopes: "ABC", }, { desc: "Tests a valid Request with OAuth2 Authentication and two scope to check", active: true, userScopes: "ABC DHHG kjdk", - clientId: "ClientId", - userId: "UserId", + clientID: "ClientId", + userID: "UserId", expectedScopes: "ABC kjdk", }, } for _, tC := range testCases { t.Run(tC.desc, func(t *testing.T) { w := httptest.NewRecorder() - r := httptest.NewRequest("GET", "/", nil) + r := httptest.NewRequest(http.MethodGet, "/", nil) r.Header.Add("Authorization", "Bearer bearer") auth := NewAuthorizer(&tokenIntrospectedSuccessful{&IntrospectResponse{ Active: tC.active, Scope: tC.userScopes, - ClientID: tC.clientId, - UserID: tC.userId, + ClientID: tC.clientID, + UserID: tC.userID, }}, &Config{}) if tC.expectedScopes != "" { auth = auth.WithScope(tC.expectedScopes) } + authorize, b := auth.Authorize(r, w) resp := w.Result() body, err := io.ReadAll(resp.Body) - resp.Body.Close() + + defer func() { + err := resp.Body.Close() + assert.NoError(t, err) + }() + if err != nil { t.Fatal(err) } + if !b || authorize == nil { t.Errorf("Expected succesfull Authentication, but was not succesfull with code %d and body %q", resp.StatusCode, string(body)) return } + to, _ := security.GetTokenFromContext(authorize) + tok, ok := to.(*token) + if !ok || tok.value != "bearer" || tok.scope != Scope(tC.userScopes) || tok.clientID != tC.clientID || tok.userID != tC.userID { + require.IsType(t, auth.introspection, &tokenIntrospectedSuccessful{}) - if !ok || tok.value != "bearer" || tok.scope != Scope(tC.userScopes) || tok.clientID != tC.clientId || tok.userID != tC.userId { - t.Errorf("Expected %v but got %v", auth.introspection.(*tokenIntrospectedSuccessful).response, tok) + tis, ok := auth.introspection.(*tokenIntrospectedSuccessful) + if ok { + t.Errorf("Expected %v but got %v", tis.response, tok) + } } }) } @@ -168,23 +189,31 @@ func TestAuthenticationSuccessScopeError(t *testing.T) { }}, &Config{}).WithScope("DE") w := httptest.NewRecorder() - r := httptest.NewRequest("GET", "/", nil) + r := httptest.NewRequest(http.MethodGet, "/", nil) r.Header.Add("Authorization", "Bearer bearer") _, b := auth.Authorize(r, w) resp := w.Result() body, err := io.ReadAll(resp.Body) - defer resp.Body.Close() + + defer func() { + err := resp.Body.Close() + assert.NoError(t, err) + }() + if err != nil { t.Fatal(err) } + if b { t.Errorf("Expected error in Authentication, but was succesfull with code %d and body %v", resp.StatusCode, string(body)) } + if got, ex := w.Code, http.StatusForbidden; got != ex { t.Errorf("Expected status code %d, got %d", ex, got) } + if got, ex := string(body), "Forbidden - requires scope \"DE\"\n"; got != ex { t.Errorf("Expected status code %q, got %q", ex, got) } @@ -226,18 +255,25 @@ func TestAuthenticationWithErrors(t *testing.T) { t.Run(tC.desc, func(t *testing.T) { auth := NewAuthorizer(&tokenInspectorWithError{returnedErr: tC.returnedErr}, &Config{}) w := httptest.NewRecorder() - r := httptest.NewRequest("GET", "/", nil) + r := httptest.NewRequest(http.MethodGet, "/", nil) r.Header.Add("Authorization", "Bearer bearer") + _, b := auth.Authorize(r, w) resp := w.Result() body, err := io.ReadAll(resp.Body) - defer resp.Body.Close() + + defer func() { + err := resp.Body.Close() + assert.NoError(t, err) + }() + if err != nil { t.Fatal(err) } + if b { - t.Errorf("Expected error in authentication, but was succesful with code %d and body %v", resp.StatusCode, string(body)) + t.Errorf("Expected error in authentication, but was successful with code %d and body %v", resp.StatusCode, string(body)) } if got, ex := w.Code, tC.expectedCode; got != ex { @@ -266,8 +302,10 @@ func Example() { if err != nil { panic(err) } + return } + _, err := fmt.Fprintf(w, "Your client may not have the right scopes to see the secret code") if err != nil { panic(err) @@ -275,8 +313,9 @@ func Example() { }) srv := &http.Server{ - Handler: r, - Addr: "127.0.0.1:8000", + Handler: r, + Addr: "127.0.0.1:8000", + ReadHeaderTimeout: 30 * time.Second, } log.Fatal(srv.ListenAndServe()) @@ -290,7 +329,7 @@ func TestRequest(t *testing.T) { scope: Scope("scope1 scope2"), } - r := httptest.NewRequest("GET", "http://example.com", nil) + r := httptest.NewRequest(http.MethodGet, "http://example.com", nil) ctx := security.ContextWithToken(r.Context(), &to) r = r.WithContext(ctx) @@ -303,7 +342,7 @@ func TestRequest(t *testing.T) { } func TestRequestWithNoToken(t *testing.T) { - r := httptest.NewRequest("GET", "http://example.com", nil) + r := httptest.NewRequest(http.MethodGet, "http://example.com", nil) r2 := Request(r) header := r2.Header.Get("Authorization") @@ -402,6 +441,7 @@ func TestUnsuccessfulAccessors(t *testing.T) { func TestWithBearerToken(t *testing.T) { ctx := context.Background() ctx = WithBearerToken(ctx, "some access token") + token, ok := security.GetTokenFromContext(ctx) if !ok || token.GetValue() != "some access token" { t.Error("could not store bearer token in context") @@ -419,14 +459,17 @@ func TestAddScope(t *testing.T) { wantCtx := context.Background() wantCtx = WithBearerToken(wantCtx, "some access token") + tok, ok := security.GetTokenFromContext(wantCtx) if !ok { t.Error("could not get token from context") } + ouathToken, ok := tok.(*token) if !ok { t.Error("could not convert token to oauth token") } + ouathToken.scope = "scope1" tests := []struct { @@ -446,14 +489,17 @@ func TestAddScope(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { got := AddScope(tt.args.ctx, tt.args.scope) + gotTok, ok := security.GetTokenFromContext(got) if !ok { t.Error("could not get token from context") } + gotOauthToken, ok := gotTok.(*token) if !ok { t.Error("could not convert token to oauth token") } + if gotOauthToken.scope != tt.want.scope { t.Errorf("AddScope() = %v, want %v", gotOauthToken.scope, tt.want.scope) } diff --git a/http/oauth2/scope.go b/http/oauth2/scope.go index 07c2a5d8f..39ca11369 100644 --- a/http/oauth2/scope.go +++ b/http/oauth2/scope.go @@ -6,7 +6,7 @@ import ( "strings" ) -// Scope represents an OAuth 2 access token scope +// Scope represents an OAuth 2 access token scope. type Scope string // IsIncludedIn checks if the permissions of a scope s are also included diff --git a/http/oauth2/scope_test.go b/http/oauth2/scope_test.go index 01d95936c..bc0181657 100644 --- a/http/oauth2/scope_test.go +++ b/http/oauth2/scope_test.go @@ -35,6 +35,7 @@ func TestScope_Add(t *testing.T) { type args struct { scope string } + tests := []struct { name string s Scope diff --git a/http/oidc/config.go b/http/oidc/config.go index b97357187..33a3a6b53 100644 --- a/http/oidc/config.go +++ b/http/oidc/config.go @@ -2,8 +2,8 @@ package oidc -// Config for OIDC based on swagger +// Config for OIDC based on swagger. type Config struct { Description string - OpenIdConnectURL string `json:"openIdConnectUrl"` + OpenIDConnectURL string `json:"openIdConnectUrl"` } diff --git a/http/router.go b/http/router.go index 2531c5b1f..3ded0f6eb 100755 --- a/http/router.go +++ b/http/router.go @@ -6,9 +6,8 @@ import ( "net/http" "net/http/pprof" - "github.com/pace/bricks/maintenance/tracing" - "github.com/gorilla/mux" + "github.com/pace/bricks/http/middleware" "github.com/pace/bricks/locale" "github.com/pace/bricks/maintenance/errors" @@ -16,11 +15,12 @@ import ( "github.com/pace/bricks/maintenance/health/servicehealthcheck" "github.com/pace/bricks/maintenance/log" "github.com/pace/bricks/maintenance/metric" + "github.com/pace/bricks/maintenance/tracing" redactMdw "github.com/pace/bricks/pkg/redact/middleware" ) // Router returns the default microservice endpoints for -// health, metrics and debugging +// health, metrics and debugging. func Router() *mux.Router { r := mux.NewRouter() @@ -53,10 +53,10 @@ func Router() *mux.Router { // report Client ID back to caller r.Use(middleware.ClientID) - // support redacting of data accross the full request scope + // support redacting of data across the full request scope r.Use(redactMdw.Redact) - // makes some infos about the request accessable from the context + // makes some infos about the request accessible from the context r.Use(middleware.RequestInContext) // for prometheus diff --git a/http/router_test.go b/http/router_test.go index bf43f761b..3a05459ce 100644 --- a/http/router_test.go +++ b/http/router_test.go @@ -11,19 +11,27 @@ import ( "testing" "github.com/gorilla/mux" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/pace/bricks/http/jsonapi/runtime" "github.com/pace/bricks/maintenance/health" - "github.com/stretchr/testify/require" ) func TestHealthHandler(t *testing.T) { rec := httptest.NewRecorder() - req := httptest.NewRequest("GET", "/health/liveness", nil) + req := httptest.NewRequest(http.MethodGet, "/health/liveness", nil) Router().ServeHTTP(rec, req) resp := rec.Result() - require.Equal(t, 200, resp.StatusCode) + + { + err := resp.Body.Close() + assert.NoError(t, err) + } + + require.Equal(t, http.StatusOK, resp.StatusCode) data, err := io.ReadAll(resp.Body) require.NoError(t, err) @@ -52,19 +60,27 @@ func TestHealthRoutes(t *testing.T) { expectedResult: "OK\n", title: "route liveness", }} + health.SetCustomReadinessCheck(func(w http.ResponseWriter, r *http.Request) { _, err := fmt.Fprint(w, "Ready") require.NoError(t, err) }) + for _, tC := range tCs { t.Run(tC.title, func(t *testing.T) { rec := httptest.NewRecorder() - req := httptest.NewRequest("GET", tC.route, nil) + req := httptest.NewRequest(http.MethodGet, tC.route, nil) Router().ServeHTTP(rec, req) resp := rec.Result() data, err := io.ReadAll(resp.Body) + + { + err := resp.Body.Close() + assert.NoError(t, err) + } + require.NoError(t, err) require.Equal(t, tC.expectedResult, string(data)) }) @@ -73,13 +89,13 @@ func TestHealthRoutes(t *testing.T) { func TestCustomRoutes(t *testing.T) { rec := httptest.NewRecorder() - req := httptest.NewRequest("GET", "/foo/bar", nil) + req := httptest.NewRequest(http.MethodGet, "/foo/bar", nil) // example of a service foo exposing api bar fooRouter := mux.NewRouter() fooRouter.HandleFunc("/foo/bar", func(w http.ResponseWriter, r *http.Request) { runtime.WriteError(w, http.StatusNotImplemented, fmt.Errorf("Some error")) - }).Methods("GET") + }).Methods(http.MethodGet) r := Router() // service routers will be mounted like this @@ -89,6 +105,11 @@ func TestCustomRoutes(t *testing.T) { resp := rec.Result() + defer func() { + err := resp.Body.Close() + assert.NoError(t, err) + }() + require.Equal(t, 501, resp.StatusCode, "Expected /foo/bar to respond with 501") var e struct { diff --git a/http/security/apikey/authorizer.go b/http/security/apikey/authorizer.go index 72e9619f1..5a79a0cab 100644 --- a/http/security/apikey/authorizer.go +++ b/http/security/apikey/authorizer.go @@ -30,7 +30,7 @@ type token struct { value string } -// GetValue returns the api key +// GetValue returns the api key. func (b *token) GetValue() string { return b.value } @@ -42,22 +42,26 @@ func NewAuthorizer(authConfig *Config, apiKey string) *Authorizer { // Authorize authorizes a request based on the configured api key the config of the security schema // Success: A context with a token containing the api key and true -// Error: the unchanged request context and false. the response already contains the error message +// Error: the unchanged request context and false. the response already contains the error message. func (a *Authorizer) Authorize(r *http.Request, w http.ResponseWriter) (context.Context, bool) { key := security.GetBearerTokenFromHeader(r.Header.Get(a.authConfig.Name)) if key == "" { log.Req(r).Info().Msg("No Api Key present in field " + a.authConfig.Name) http.Error(w, "Unauthorized", http.StatusUnauthorized) + return r.Context(), false } + if key == a.apiKey { return security.ContextWithToken(r.Context(), &token{key}), true } + http.Error(w, "ApiKey not valid", http.StatusUnauthorized) + return r.Context(), false } -// CanAuthorizeRequest returns true, if the request contains a token in the configured header, otherwise false +// CanAuthorizeRequest returns true, if the request contains a token in the configured header, otherwise false. func (a *Authorizer) CanAuthorizeRequest(r http.Request) bool { return security.GetBearerTokenFromHeader(r.Header.Get(a.authConfig.Name)) != "" } diff --git a/http/security/apikey/authorizer_test.go b/http/security/apikey/authorizer_test.go index 781141aff..0d6985117 100644 --- a/http/security/apikey/authorizer_test.go +++ b/http/security/apikey/authorizer_test.go @@ -7,29 +7,37 @@ import ( "net/http" "net/http/httptest" "testing" + + "github.com/stretchr/testify/assert" ) func TestApiKeyAuthenticationSuccessful(t *testing.T) { auth := NewAuthorizer(&Config{Name: "Authorization"}, "testkey") w := httptest.NewRecorder() - r := httptest.NewRequest("GET", "/", nil) + r := httptest.NewRequest(http.MethodGet, "/", nil) r.Header.Add("Authorization", "Bearer testkey") _, b := auth.Authorize(r, w) resp := w.Result() + defer func() { + _ = resp.Body.Close() + }() + body, err := io.ReadAll(resp.Body) - resp.Body.Close() if err != nil { t.Fatal(err) } + if !b { t.Errorf("Expected no error in authentication, but failed with code %d and body %v", resp.StatusCode, string(body)) } + if got, ex := w.Code, http.StatusOK; got != ex { t.Errorf("Expected status code %d, got %d", ex, got) } + if got, ex := string(body), ""; got != ex { t.Errorf("Expected status code %q, got %q", ex, got) } @@ -39,23 +47,29 @@ func TestApiKeyAuthenticationError(t *testing.T) { auth := NewAuthorizer(&Config{Name: "Authorization"}, "testkey") w := httptest.NewRecorder() - r := httptest.NewRequest("GET", "/", nil) + r := httptest.NewRequest(http.MethodGet, "/", nil) r.Header.Add("Authorization", "Bearer wrongKey") _, b := auth.Authorize(r, w) resp := w.Result() + defer func() { + _ = resp.Body.Close() + }() + body, err := io.ReadAll(resp.Body) - resp.Body.Close() if err != nil { t.Fatal(err) } + if b { t.Errorf("Expected error in Authentication, but was succesfull with code %d and body %v", resp.StatusCode, string(body)) } + if got, ex := w.Code, http.StatusUnauthorized; got != ex { t.Errorf("Expected status code %d, got %d", ex, got) } + if got, ex := string(body), "ApiKey not valid\n"; got != ex { t.Errorf("Expected error massage %q, got %q", ex, got) } @@ -65,22 +79,30 @@ func TestApiKeyAuthenticationNoKey(t *testing.T) { auth := NewAuthorizer(&Config{Name: "Authorization"}, "testkey") w := httptest.NewRecorder() - r := httptest.NewRequest("GET", "/", nil) + r := httptest.NewRequest(http.MethodGet, "/", nil) _, b := auth.Authorize(r, w) resp := w.Result() body, err := io.ReadAll(resp.Body) - resp.Body.Close() + + defer func() { + err = resp.Body.Close() + assert.NoError(t, err) + }() + if err != nil { t.Fatal(err) } + if b { t.Errorf("Expected error in Authentication, but was succesfull with code %d and body %v", resp.StatusCode, string(body)) } + if got, ex := w.Code, http.StatusUnauthorized; got != ex { t.Errorf("Expected status code %d, got %d", ex, got) } + if got, ex := string(body), "Unauthorized\n"; got != ex { t.Errorf("Expected status code %q, got %q", ex, got) } diff --git a/http/security/authorizer.go b/http/security/authorizer.go index 0377ae35c..169daf7db 100644 --- a/http/security/authorizer.go +++ b/http/security/authorizer.go @@ -8,7 +8,7 @@ import ( ) // Authorizer describes the needed functions for authorization, -// already implemented in oauth2.Authorizer and apikey.Authorizer +// already implemented in oauth2.Authorizer and apikey.Authorizer. type Authorizer interface { // Authorize should authorize a request. // Success: returns a context with information of the authorization @@ -18,7 +18,7 @@ type Authorizer interface { } // CanAuthorize offers a method to check if an -// authorizer can authorize a request +// authorizer can authorize a request. type CanAuthorize interface { // CanAuthorizeRequest should check if a request contains the needed information to be authorized CanAuthorizeRequest(r http.Request) bool diff --git a/http/security/helper.go b/http/security/helper.go index 1528aeab3..e82efa981 100644 --- a/http/security/helper.go +++ b/http/security/helper.go @@ -20,7 +20,7 @@ func (ts TokenString) GetValue() string { return string(ts) } -// prefix of the Authorization header +// prefix of the Authorization header. const headerPrefix = "Bearer " var tokenKey = ctx("Token") @@ -34,10 +34,11 @@ func GetBearerTokenFromHeader(authHeader string) string { if !hasPrefix { return "" } + return strings.TrimPrefix(authHeader, headerPrefix) } -// ContextWithToken creates a new Context with the token +// ContextWithToken creates a new Context with the token. func ContextWithToken(targetCtx context.Context, token Token) context.Context { return context.WithValue(targetCtx, tokenKey, token) } @@ -48,11 +49,13 @@ func GetTokenFromContext(ctx context.Context) (Token, bool) { if val == nil { return nil, false } + tok, ok := val.(Token) + return tok, ok } -// GetAuthHeader creates the valid value for the authentication header +// GetAuthHeader creates the valid value for the authentication header. func GetAuthHeader(tok Token) string { return headerPrefix + tok.GetValue() } diff --git a/http/server.go b/http/server.go index e73f673bc..31eb16f1f 100644 --- a/http/server.go +++ b/http/server.go @@ -9,6 +9,7 @@ import ( "time" "github.com/caarlos0/env/v10" + "github.com/pace/bricks/maintenance/log" ) @@ -26,19 +27,19 @@ type config struct { WriteTimeout time.Duration `env:"WRITE_TIMEOUT" envDefault:"60s"` } -// addrOrPort returns ADDR if it is defined, otherwise PORT is used +// addrOrPort returns ADDR if it is defined, otherwise PORT is used. func (cfg config) addrOrPort() string { if cfg.Addr != "" { return cfg.Addr } + return ":" + strconv.Itoa(cfg.Port) } var cfg config func parseConfig() { - err := env.Parse(&cfg) - if err != nil { + if err := env.Parse(&cfg); err != nil { log.Fatalf("Failed to parse server environment: %v", err) } } @@ -57,7 +58,7 @@ func Server(handler http.Handler) *http.Server { } } -// Environment returns the name of the current server environment +// Environment returns the name of the current server environment. func Environment() string { return cfg.Environment } diff --git a/http/server_test.go b/http/server_test.go index 9d682df7f..5edab5bf2 100644 --- a/http/server_test.go +++ b/http/server_test.go @@ -6,18 +6,34 @@ import ( "os" "testing" "time" + + "github.com/stretchr/testify/require" ) func TestServer(t *testing.T) { // Defaults - os.Setenv("ADDR", "") - os.Setenv("PORT", "") - os.Setenv("MAX_HEADER_BYTES", "") - os.Setenv("IDLE_TIMEOUT", "") - os.Setenv("READ_TIMEOUT", "") - os.Setenv("WRITE_TIMEOUT", "") + err := os.Setenv("ADDR", "") + require.NoError(t, err) + + err = os.Setenv("PORT", "") + require.NoError(t, err) + + err = os.Setenv("MAX_HEADER_BYTES", "") + require.NoError(t, err) + + err = os.Setenv("IDLE_TIMEOUT", "") + require.NoError(t, err) + + err = os.Setenv("READ_TIMEOUT", "") + require.NoError(t, err) + + err = os.Setenv("WRITE_TIMEOUT", "") + require.NoError(t, err) + parseConfig() + s := Server(nil) + cases := []struct { env string expected, actual interface{} @@ -35,14 +51,28 @@ func TestServer(t *testing.T) { } // custom - os.Setenv("ADDR", ":5432") - os.Setenv("PORT", "1234") - os.Setenv("MAX_HEADER_BYTES", "100") - os.Setenv("IDLE_TIMEOUT", "1s") - os.Setenv("READ_TIMEOUT", "2s") - os.Setenv("WRITE_TIMEOUT", "3s") + err = os.Setenv("ADDR", ":5432") + require.NoError(t, err) + + err = os.Setenv("PORT", "1234") + require.NoError(t, err) + + err = os.Setenv("MAX_HEADER_BYTES", "100") + require.NoError(t, err) + + err = os.Setenv("IDLE_TIMEOUT", "1s") + require.NoError(t, err) + + err = os.Setenv("READ_TIMEOUT", "2s") + require.NoError(t, err) + + err = os.Setenv("WRITE_TIMEOUT", "3s") + require.NoError(t, err) + parseConfig() + s = Server(nil) + cases = []struct { env string expected, actual interface{} @@ -62,15 +92,21 @@ func TestServer(t *testing.T) { func TestEnvironment(t *testing.T) { // Defaults - os.Setenv("ENVIRONMENT", "") + err := os.Setenv("ENVIRONMENT", "") + require.NoError(t, err) + parseConfig() + if Environment() != "edge" { t.Errorf("Expected edge, got: %q", Environment()) } // custom - os.Setenv("ENVIRONMENT", "production") + err = os.Setenv("ENVIRONMENT", "production") + require.NoError(t, err) + parseConfig() + if Environment() != "production" { t.Errorf("Expected production, got: %q", Environment()) } diff --git a/http/transport/attempt_round_tripper.go b/http/transport/attempt_round_tripper.go index 77a78cafa..acb59f2ee 100644 --- a/http/transport/attempt_round_tripper.go +++ b/http/transport/attempt_round_tripper.go @@ -49,11 +49,13 @@ func attemptFromCtx(ctx context.Context) int32 { if !ok { return 0 } + return a } func transportWithAttempt(rt http.RoundTripper) http.RoundTripper { ar := &attemptRoundTripper{attempt: 0} ar.SetTransport(rt) + return ar } diff --git a/http/transport/chainable.go b/http/transport/chainable.go index 970745b95..ad4f636b6 100644 --- a/http/transport/chainable.go +++ b/http/transport/chainable.go @@ -4,7 +4,7 @@ package transport import "net/http" -// ChainableRoundTripper models a chainable round tripper +// ChainableRoundTripper models a chainable round tripper. type ChainableRoundTripper interface { http.RoundTripper @@ -41,7 +41,7 @@ type RoundTripperChain struct { } // Chain returns a round tripper chain with the specified chainable round trippers and http.DefaultTransport as transport. -// The transport can be overriden by using the Final method. +// The transport can be overridden by using the Final method. func Chain(rt ...ChainableRoundTripper) *RoundTripperChain { final := &finalRoundTripper{transport: http.DefaultTransport} c := &RoundTripperChain{first: final, current: final, final: final} @@ -67,6 +67,7 @@ func (c *RoundTripperChain) Use(rt ChainableRoundTripper) *RoundTripperChain { c.current.SetTransport(rt) rt.SetTransport(c.final) + c.current = rt return c diff --git a/http/transport/chainable_test.go b/http/transport/chainable_test.go index 924c2eff1..e85bffd64 100644 --- a/http/transport/chainable_test.go +++ b/http/transport/chainable_test.go @@ -14,7 +14,7 @@ import ( // TestRoundTripperRace will detect race conditions // in any RoundTripper by sending concurrent requests. // Make sure to use the -race parameter when -// executing this test +// executing this test. func TestRoundTripperRace(t *testing.T) { client := http.Client{ Transport: NewDefaultTransportChain(), @@ -33,12 +33,18 @@ func TestRoundTripperRace(t *testing.T) { go func() { for i := 0; i < 10; i++ { - client.Get(server.URL + "/test001") // nolint: errcheck + resp, err := client.Get(server.URL + "/test001") + if err == nil { + _ = resp.Body.Close() + } } }() for i := 0; i < 10; i++ { - client.Get(server.URL + "/test002") // nolint: errcheck + resp, err := client.Get(server.URL + "/test002") + if err == nil { + _ = resp.Body.Close() + } } } @@ -48,16 +54,17 @@ func TestRoundTripperChaining(t *testing.T) { c := Chain().Final(transport) url := "/foo" - req := httptest.NewRequest("GET", url, nil) + req := httptest.NewRequest(http.MethodGet, url, nil) - _, err := c.RoundTrip(req) + _, err := c.RoundTrip(req) //nolint:bodyclose if err != nil { t.Fatalf("Expected err to be nil, got %#v", err) } - if v := transport.req.Method; v != "GET" { - t.Errorf("Expected method %q, got %q", "GET", v) + if v := transport.req.Method; v != http.MethodGet { + t.Errorf("Expected method %q, got %q", http.MethodGet, v) } + if v := transport.req.URL.String(); v != url { t.Errorf("Expected URL %q, got %q", url, v) } @@ -68,19 +75,21 @@ func TestRoundTripperChaining(t *testing.T) { c.Use(&addHeaderRoundTripper{key: "foo", value: "bar"}).Final(transport) url := "/foo" - req := httptest.NewRequest("GET", url, nil) + req := httptest.NewRequest(http.MethodGet, url, nil) - _, err := c.RoundTrip(req) + _, err := c.RoundTrip(req) //nolint:bodyclose if err != nil { t.Fatalf("Expected err to be nil, got %#v", err) } - if v := transport.req.Method; v != "GET" { - t.Errorf("Expected method %v, got %v", "GET", v) + if v := transport.req.Method; v != http.MethodGet { + t.Errorf("Expected method %v, got %v", http.MethodGet, v) } + if v := transport.req.URL.String(); v != url { t.Errorf("Expected URL %v, got %v", url, v) } + if v, ex := transport.req.Header.Get("foo"), "bar"; v != ex { t.Errorf("Expected header foo to eq %v, got %v", ex, v) } @@ -93,22 +102,25 @@ func TestRoundTripperChaining(t *testing.T) { c := Chain(rt1, rt2, rt3).Final(transport) url := "/foo" - req := httptest.NewRequest("GET", url, nil) + req := httptest.NewRequest(http.MethodGet, url, nil) - _, err := c.RoundTrip(req) + _, err := c.RoundTrip(req) //nolint:bodyclose if err != nil { t.Fatalf("Expected err to be nil, got %#v", err) } - if v := transport.req.Method; v != "GET" { - t.Errorf("Expected method %v, got %v", "GET", v) + if v := transport.req.Method; v != http.MethodGet { + t.Errorf("Expected method %v, got %v", http.MethodGet, v) } + if v := transport.req.URL.String(); v != url { t.Errorf("Expected URL %v, got %v", url, v) } + if v, ex := transport.req.Header.Get("foo"), "baroverride"; v != ex { t.Errorf("Expected header foo to eq %v, got %v", ex, v) } + if v, ex := transport.req.Header.Get("Authorization"), "Bearer 123"; v != ex { t.Errorf("Expected header Authorization to eq %v, got %v", ex, v) } diff --git a/http/transport/circuit_breaker_tripper.go b/http/transport/circuit_breaker_tripper.go index 25b8a50e4..d24f0724f 100644 --- a/http/transport/circuit_breaker_tripper.go +++ b/http/transport/circuit_breaker_tripper.go @@ -48,6 +48,7 @@ func NewCircuitBreakerTripper(settings gobreaker.Settings) *circuitBreakerTrippe }, []string{"from", "to"}) var ok bool + var are prometheus.AlreadyRegisteredError if err := prometheus.Register(stateSwitchCounterVec); errors.As(err, &are) { stateSwitchCounterVec, ok = are.ExistingCollector.(*prometheus.CounterVec) @@ -71,20 +72,20 @@ func NewCircuitBreakerTripper(settings gobreaker.Settings) *circuitBreakerTrippe return &circuitBreakerTripper{breaker: gobreaker.NewCircuitBreaker(settings)} } -// Transport returns the RoundTripper to make HTTP requests +// Transport returns the RoundTripper to make HTTP requests. func (c *circuitBreakerTripper) Transport() http.RoundTripper { return c.transport } -// SetTransport sets the RoundTripper to make HTTP requests +// SetTransport sets the RoundTripper to make HTTP requests. func (c *circuitBreakerTripper) SetTransport(rt http.RoundTripper) { c.transport = rt } -// RoundTrip executes a single HTTP transaction via Transport() +// RoundTrip executes a single HTTP transaction via Transport(). func (c *circuitBreakerTripper) RoundTrip(req *http.Request) (*http.Response, error) { resp, err := c.breaker.Execute(func() (interface{}, error) { - return c.transport.RoundTrip(req) + return c.transport.RoundTrip(req) //nolint:bodyclose }) if err != nil { switch { @@ -96,5 +97,10 @@ func (c *circuitBreakerTripper) RoundTrip(req *http.Request) (*http.Response, er } } - return resp.(*http.Response), nil + out, ok := resp.(*http.Response) + if !ok { + return nil, fmt.Errorf("unexpected response type: %T", resp) + } + + return out, nil } diff --git a/http/transport/circuit_breaker_tripper_test.go b/http/transport/circuit_breaker_tripper_test.go index 4272f3f7a..1db68457f 100644 --- a/http/transport/circuit_breaker_tripper_test.go +++ b/http/transport/circuit_breaker_tripper_test.go @@ -13,21 +13,19 @@ import ( ) func TestCircuitBreakerTripper(t *testing.T) { - req := httptest.NewRequest("GET", "/foo", nil) + req := httptest.NewRequest(http.MethodGet, "/foo", nil) t.Run("with_default_settings", func(t *testing.T) { breaker := NewDefaultCircuitBreakerTripper("testcircuitbreaker") chain := Chain(breaker).Final(&failingRoundTripper{}) for i := 0; i < 6; i++ { - if _, err := chain.RoundTrip(req); errors.Is(err, ErrCircuitBroken) { - t.Errorf("got err=%q, before expected", ErrCircuitBroken) - } + _, err := chain.RoundTrip(req) //nolint:bodyclose + require.NotErrorIs(t, err, ErrCircuitBroken) } - if _, err := chain.RoundTrip(req); !errors.Is(err, ErrCircuitBroken) { - t.Errorf("wanted err=%q, got err=%q", ErrCircuitBroken, err) - } + _, err := chain.RoundTrip(req) //nolint:bodyclose + require.ErrorIs(t, err, ErrCircuitBroken) }) t.Run("panic_on_empty_name", func(t *testing.T) { @@ -48,6 +46,11 @@ func TestCircuitBreakerTripper(t *testing.T) { resp, err := chain.RoundTrip(req) require.NoError(t, err, "expected no err, got err=%q", err) + defer func() { + err := resp.Body.Close() + assert.NoError(t, err) + }() + gotBodyStr, err := io.ReadAll(resp.Body) require.NoError(t, err, "failed reading response body no err, got err=%q", err) diff --git a/http/transport/default_transport.go b/http/transport/default_transport.go index 8a1bd52e2..7850b8f4b 100644 --- a/http/transport/default_transport.go +++ b/http/transport/default_transport.go @@ -16,9 +16,9 @@ func NewDefaultTransportChain() *RoundTripperChain { ) } -// NewDefaultTransportChain returns a transport chain with retry, jaeger and logging support. +// NewDefaultTransportChainWithExternalName returns a transport chain with retry, jaeger and logging support. // If not explicitly finalized via `Final` it uses `http.DefaultTransport` as finalizer. -// The passed name is recorded as external dependency +// The passed name is recorded as external dependency. func NewDefaultTransportChainWithExternalName(name string) *RoundTripperChain { return Chain( &ExternalDependencyRoundTripper{name: name}, diff --git a/http/transport/default_transport_test.go b/http/transport/default_transport_test.go index 78e656c0b..048b27087 100644 --- a/http/transport/default_transport_test.go +++ b/http/transport/default_transport_test.go @@ -12,14 +12,22 @@ import ( "os" "testing" - "github.com/pace/bricks/maintenance/log" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/pace/bricks/maintenance/log" ) func TestNewDefaultTransportChain(t *testing.T) { old := os.Getenv("HTTP_TRANSPORT_DUMP") - defer os.Setenv("HTTP_TRANSPORT_DUMP", old) - os.Setenv("HTTP_TRANSPORT_DUMP", "request,response,body") + + defer func() { + err := os.Setenv("HTTP_TRANSPORT_DUMP", old) + require.NoError(t, err) + }() + + err := os.Setenv("HTTP_TRANSPORT_DUMP", "request,response,body") + require.NoError(t, err) t.Run("Finalizer not set explicitly", func(t *testing.T) { b := "Hello World" @@ -29,19 +37,32 @@ func TestNewDefaultTransportChain(t *testing.T) { retry++ if retry == 5 { w.WriteHeader(http.StatusOK) - fmt.Fprint(w, b) + + _, err := fmt.Fprint(w, b) + require.NoError(t, err) + return } + w.WriteHeader(http.StatusBadGateway) - fmt.Fprint(w, b) + + _, err := fmt.Fprint(w, b) + require.NoError(t, err) })) - req := httptest.NewRequest("GET", ts.URL, nil) + req := httptest.NewRequest(http.MethodGet, ts.URL, nil) req = req.WithContext(log.WithContext(context.Background())) + resp, err := tr.RoundTrip(req) if err != nil { t.Fatal(err) } + + defer func() { + err := resp.Body.Close() + assert.NoError(t, err) + }() + ts.Close() assert.Equal(t, retry, 5) @@ -60,17 +81,24 @@ func TestNewDefaultTransportChain(t *testing.T) { tr := &transportWithBody{body: "abc"} dt := NewDefaultTransportChain().Final(tr) - req := httptest.NewRequest("GET", "/foo", nil) + req := httptest.NewRequest(http.MethodGet, "/foo", nil) req = req.WithContext(log.WithContext(context.Background())) + resp, err := dt.RoundTrip(req) if err != nil { t.Fatalf("Expected err to be nil, got %#v", err) } + defer func() { + err := resp.Body.Close() + assert.NoError(t, err) + }() + body, err := io.ReadAll(resp.Body) if err != nil { t.Fatalf("Expected readable body, got error: %q", err.Error()) } + if ex, got := tr.body, string(body); ex != got { t.Errorf("Expected body %q, got %q", ex, got) } @@ -84,7 +112,7 @@ type transportWithBody struct { func (t *transportWithBody) RoundTrip(req *http.Request) (*http.Response, error) { body := io.NopCloser(bytes.NewReader([]byte(t.body))) - resp := &http.Response{Body: body, StatusCode: 200} + resp := &http.Response{Body: body, StatusCode: http.StatusOK} return resp, nil } diff --git a/http/transport/dump_options.go b/http/transport/dump_options.go index 2e37d1f97..ad25981be 100644 --- a/http/transport/dump_options.go +++ b/http/transport/dump_options.go @@ -8,11 +8,11 @@ import ( func NewDumpOptions(opts ...DumpOption) (DumpOptions, error) { dumpOptions := DumpOptions(map[string]bool{}) for _, opt := range opts { - err := opt(dumpOptions) - if err != nil { + if err := opt(dumpOptions); err != nil { return nil, err } } + return dumpOptions, nil } @@ -29,6 +29,7 @@ func (o DumpOptions) AnyEnabled(options ...string) bool { return true } } + return false } @@ -39,7 +40,9 @@ func WithDumpOption(option string, enabled bool) DumpOption { if !isDumpOptionValid(option) { return fmt.Errorf("invalid dump option %q", option) } + o[option] = enabled + return nil } } @@ -80,8 +83,10 @@ func mergeDumpOptions(globalOptions, reqOptions DumpOptions) DumpOptions { // req option already exists, ignore the global one continue } + reqOptions[globalKey] = globalVal } + return reqOptions } @@ -91,11 +96,13 @@ func CtxWithDumpRoundTripperOptions(ctx context.Context, opts DumpOptions) conte if opts == nil { return ctx } + return context.WithValue(ctx, dumpRoundTripperCtxKey{}, opts) } func DumpRoundTripperOptionsFromCtx(ctx context.Context) DumpOptions { do := ctx.Value(dumpRoundTripperCtxKey{}) dumpOptions, _ := do.(DumpOptions) + return dumpOptions } diff --git a/http/transport/dump_round_tripper.go b/http/transport/dump_round_tripper.go index 709408b96..f8f5f0b56 100644 --- a/http/transport/dump_round_tripper.go +++ b/http/transport/dump_round_tripper.go @@ -17,7 +17,7 @@ import ( ) // DumpRoundTripper dumps requests and responses in one log event. -// This is not part of te request logger to be able to filter dumps more easily +// This is not part of te request logger to be able to filter dumps more easily. type DumpRoundTripper struct { transport http.RoundTripper @@ -38,18 +38,22 @@ type dumpRoundTripperConfig struct { func roundTripConfigViaEnv() DumpRoundTripperOption { return func(rt *DumpRoundTripper) (*DumpRoundTripper, error) { var cfg dumpRoundTripperConfig - err := env.Parse(&cfg) - if err != nil { + + if err := env.Parse(&cfg); err != nil { return rt, fmt.Errorf("failed to parse dump round tripper environment: %w", err) } + for _, option := range cfg.Options { if !isDumpOptionValid(option) { return nil, fmt.Errorf("invalid dump option %q", option) } + rt.options[option] = true } + rt.blacklistAnyDumpPrefixes = cfg.BlacklistAnyDumpPrefixes rt.blacklistBodyDumpPrefixes = cfg.BlacklistBodyDumpPrefixes + return rt, nil } } @@ -60,46 +64,52 @@ func RoundTripConfig(dumpOptions ...string) DumpRoundTripperOption { if !isDumpOptionValid(option) { return nil, fmt.Errorf("invalid dump option %q", option) } + rt.options[option] = true } + return rt, nil } } // NewDumpRoundTripperEnv creates a new RoundTripper based on the configuration -// that is passed via environment variables +// that is passed via environment variables. func NewDumpRoundTripperEnv() *DumpRoundTripper { rt, err := NewDumpRoundTripper(roundTripConfigViaEnv()) if err != nil { log.Fatalf("failed to setup NewDumpRoundTripperEnv: %v", err) } + return rt } -// NewDumpRoundTripper return the roundtripper with configured options +// NewDumpRoundTripper return the roundtripper with configured options. func NewDumpRoundTripper(options ...DumpRoundTripperOption) (*DumpRoundTripper, error) { rt := &DumpRoundTripper{options: DumpOptions{}} + var err error + for _, option := range options { rt, err = option(rt) if err != nil { return rt, err } } + return rt, nil } -// Transport returns the RoundTripper to make HTTP requests +// Transport returns the RoundTripper to make HTTP requests. func (l *DumpRoundTripper) Transport() http.RoundTripper { return l.transport } -// SetTransport sets the RoundTripper to make HTTP requests +// SetTransport sets the RoundTripper to make HTTP requests. func (l *DumpRoundTripper) SetTransport(rt http.RoundTripper) { l.transport = rt } -// AnyEnabled returns true if any logging is enabled +// AnyEnabled returns true if any logging is enabled. func (l *DumpRoundTripper) AnyEnabled() bool { return l.options.AnyEnabled(DumpRoundTripperOptionRequest, DumpRoundTripperOptionRequestHEX, DumpRoundTripperOptionResponse, DumpRoundTripperOptionResponseHEX) } @@ -108,16 +118,18 @@ func (l *DumpRoundTripper) ContainsBlacklistedPrefix(url *url.URL, blacklist []s if len(blacklist) == 0 { return false } + for _, prefix := range blacklist { // TODO (juf): Do benchmark and compare against using pre-constructed prefix-tree if strings.HasPrefix(url.String(), prefix) { return true } } + return false } -// RoundTrip executes a single HTTP transaction via Transport() +// RoundTrip executes a single HTTP transaction via Transport(). func (l *DumpRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) { var redactor *redact.PatternRedactor @@ -156,6 +168,7 @@ func (l *DumpRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) if options.IsEnabled(DumpRoundTripperOptionRequest) { dl = dl.Bytes(DumpRoundTripperOptionRequest, reqDump) } + if options.IsEnabled(DumpRoundTripperOptionRequestHEX) { dl = dl.Str(DumpRoundTripperOptionRequestHEX, hex.EncodeToString(reqDump)) } @@ -177,9 +190,11 @@ func (l *DumpRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) if redactor != nil { respDump = []byte(redactor.Mask(string(respDump))) } + if options.IsEnabled(DumpRoundTripperOptionResponse) { dl = dl.Bytes(DumpRoundTripperOptionResponse, respDump) } + if options.IsEnabled(DumpRoundTripperOptionResponseHEX) { dl = dl.Str(DumpRoundTripperOptionResponseHEX, hex.EncodeToString(respDump)) } diff --git a/http/transport/dump_round_tripper_test.go b/http/transport/dump_round_tripper_test.go index e4b2bef1c..2cdd3aa39 100644 --- a/http/transport/dump_round_tripper_test.go +++ b/http/transport/dump_round_tripper_test.go @@ -5,6 +5,7 @@ package transport import ( "bytes" "context" + "net/http" "net/http/httptest" "os" "testing" @@ -24,13 +25,19 @@ func TestNewDumpRoundTripperEnv(t *testing.T) { rt := NewDumpRoundTripperEnv() assert.NotNil(t, rt) - req := httptest.NewRequest("GET", "/foo", nil) + req := httptest.NewRequest(http.MethodGet, "/foo", nil) req = req.WithContext(ctx) + rt.SetTransport(&transportWithResponse{}) - _, err := rt.RoundTrip(req) + resp, err := rt.RoundTrip(req) assert.NoError(t, err) + defer func() { + err := resp.Body.Close() + assert.NoError(t, err) + }() + assert.Equal(t, "", out.String()) }) } @@ -40,8 +47,16 @@ func TestNewDumpRoundTripperEnvDisablePrefixBasedComplete(t *testing.T) { ctx := log.Output(out).WithContext(context.Background()) require.NotPanics(t, func() { - defer os.Setenv("HTTP_TRANSPORT_DUMP_DISABLE_ALL_URL_PREFIX", os.Getenv("HTTP_TRANSPORT_DUMP_DISABLE_ALL_URL_PREFIX")) - os.Setenv("HTTP_TRANSPORT_DUMP_DISABLE_ALL_URL_PREFIX", "https://please-ignore-me") + oldEnv := os.Getenv("HTTP_TRANSPORT_DUMP_DISABLE_ALL_URL_PREFIX") + + defer func() { + err := os.Setenv("HTTP_TRANSPORT_DUMP_DISABLE_ALL_URL_PREFIX", oldEnv) + assert.NoError(t, err) + }() + + err := os.Setenv("HTTP_TRANSPORT_DUMP_DISABLE_ALL_URL_PREFIX", "https://please-ignore-me") + require.NoError(t, err) + rt, err := NewDumpRoundTripper( roundTripConfigViaEnv(), RoundTripConfig( @@ -54,12 +69,20 @@ func TestNewDumpRoundTripperEnvDisablePrefixBasedComplete(t *testing.T) { require.NoError(t, err) assert.NotNil(t, rt) - req := httptest.NewRequest("GET", "/foo", bytes.NewBufferString("Foo")) + req := httptest.NewRequest(http.MethodGet, "/foo", bytes.NewBufferString("Foo")) req = req.WithContext(ctx) + rt.SetTransport(&transportWithResponse{}) - _, err = rt.RoundTrip(req) - assert.NoError(t, err) + { + resp, err := rt.RoundTrip(req) + assert.NoError(t, err) + + defer func() { + err := resp.Body.Close() + assert.NoError(t, err) + }() + } assert.Contains(t, out.String(), `"level":"debug"`) assert.Contains(t, out.String(), `"request":"GET /foo HTTP/1.1\r\nHost: example.com\r\n\r\nFoo"`) @@ -72,11 +95,19 @@ func TestNewDumpRoundTripperEnvDisablePrefixBasedComplete(t *testing.T) { assert.Equal(t, "", out.String()) - reqWithPrefix := httptest.NewRequest("GET", "https://please-ignore-me.org/foo/", bytes.NewBufferString("Foo")) + reqWithPrefix := httptest.NewRequest(http.MethodGet, "https://please-ignore-me.org/foo/", bytes.NewBufferString("Foo")) reqWithPrefix = reqWithPrefix.WithContext(ctx) - _, err = rt.RoundTrip(reqWithPrefix) - assert.NoError(t, err) + { + resp, err := rt.RoundTrip(reqWithPrefix) + assert.NoError(t, err) + + defer func() { + err := resp.Body.Close() + assert.NoError(t, err) + }() + } + assert.Empty(t, out.String()) }) } @@ -85,9 +116,18 @@ func TestNewDumpRoundTripperEnvDisablePrefixBasedBody(t *testing.T) { out := &bytes.Buffer{} ctx := log.Output(out).WithContext(context.Background()) + log.Println(os.Environ()) require.NotPanics(t, func() { - defer os.Setenv("HTTP_TRANSPORT_DUMP_DISABLE_DUMP_BODY_URL_PREFIX", os.Getenv("HTTP_TRANSPORT_DUMP_DISABLE_DUMP_BODY_URL_PREFIX")) - os.Setenv("HTTP_TRANSPORT_DUMP_DISABLE_DUMP_BODY_URL_PREFIX", "https://please-ignore-me") + oldEnv := os.Getenv("HTTP_TRANSPORT_DUMP_DISABLE_DUMP_BODY_URL_PREFIX") + + defer func() { + err := os.Setenv("HTTP_TRANSPORT_DUMP_DISABLE_DUMP_BODY_URL_PREFIX", oldEnv) + assert.NoError(t, err) + }() + + err := os.Setenv("HTTP_TRANSPORT_DUMP_DISABLE_DUMP_BODY_URL_PREFIX", "https://please-ignore-me") + require.NoError(t, err) + rt, err := NewDumpRoundTripper( roundTripConfigViaEnv(), RoundTripConfig( @@ -100,12 +140,20 @@ func TestNewDumpRoundTripperEnvDisablePrefixBasedBody(t *testing.T) { require.NoError(t, err) assert.NotNil(t, rt) - req := httptest.NewRequest("GET", "/foo", bytes.NewBufferString("Foo")) + req := httptest.NewRequest(http.MethodGet, "/foo", bytes.NewBufferString("Foo")) req = req.WithContext(ctx) + rt.SetTransport(&transportWithResponse{}) - _, err = rt.RoundTrip(req) - assert.NoError(t, err) + { + resp, err := rt.RoundTrip(req) + assert.NoError(t, err) + + defer func() { + err := resp.Body.Close() + assert.NoError(t, err) + }() + } assert.Contains(t, out.String(), `"level":"debug"`) assert.Contains(t, out.String(), `"request":"GET /foo HTTP/1.1\r\nHost: example.com\r\n\r\nFoo"`) @@ -118,11 +166,18 @@ func TestNewDumpRoundTripperEnvDisablePrefixBasedBody(t *testing.T) { assert.Equal(t, "", out.String()) - reqWithPrefix := httptest.NewRequest("GET", "https://please-ignore-me.org/foo/", bytes.NewBufferString("Foo")) + reqWithPrefix := httptest.NewRequest(http.MethodGet, "https://please-ignore-me.org/foo/", bytes.NewBufferString("Foo")) reqWithPrefix = reqWithPrefix.WithContext(ctx) - _, err = rt.RoundTrip(reqWithPrefix) - assert.NoError(t, err) + { + resp, err := rt.RoundTrip(reqWithPrefix) + assert.NoError(t, err) + + defer func() { + err := resp.Body.Close() + assert.NoError(t, err) + }() + } assert.Contains(t, out.String(), `"level":"debug"`) assert.Contains(t, out.String(), `"request":"GET https://please-ignore-me.org/foo/ HTTP/1.1\r\n\r\n"`) @@ -148,13 +203,19 @@ func TestNewDumpRoundTripper(t *testing.T) { ) require.NoError(t, err) - req := httptest.NewRequest("GET", "/foo", bytes.NewBufferString("Foo")) + req := httptest.NewRequest(http.MethodGet, "/foo", bytes.NewBufferString("Foo")) req = req.WithContext(ctx) + rt.SetTransport(&transportWithResponse{}) - _, err = rt.RoundTrip(req) + resp, err := rt.RoundTrip(req) assert.NoError(t, err) + defer func() { + err := resp.Body.Close() + assert.NoError(t, err) + }() + assert.Contains(t, out.String(), `"level":"debug"`) assert.Contains(t, out.String(), `"request":"GET /foo HTTP/1.1\r\nHost: example.com\r\n\r\nFoo"`) assert.Contains(t, out.String(), `"request-hex":"474554202f666f6f20485454502f312e310d0a486f73743a206578616d706c652e636f6d0d0a0d0a466f6f"`) @@ -176,14 +237,20 @@ func TestNewDumpRoundTripperRedacted(t *testing.T) { ) require.NoError(t, err) - req := httptest.NewRequest("GET", "/foo", bytes.NewBufferString("Foo DE12345678909876543210 bar")) + req := httptest.NewRequest(http.MethodGet, "/foo", bytes.NewBufferString("Foo DE12345678909876543210 bar")) ctx = redact.Default.WithContext(ctx) req = req.WithContext(ctx) + rt.SetTransport(&transportWithResponse{}) - _, err = rt.RoundTrip(req) + resp, err := rt.RoundTrip(req) assert.NoError(t, err) + defer func() { + err := resp.Body.Close() + assert.NoError(t, err) + }() + assert.Contains(t, out.String(), `"level":"debug"`) assert.Contains(t, out.String(), `"request":"GET /foo HTTP/1.1\r\nHost: example.com\r\n\r\nFoo ******************3210 bar"`) assert.Contains(t, out.String(), `"response":"HTTP/0.0 000 status code 0\r\nContent-Length: 0\r\n\r\n"`) @@ -203,14 +270,20 @@ func TestNewDumpRoundTripperRedactedBasicAuth(t *testing.T) { ) require.NoError(t, err) - req := httptest.NewRequest("GET", "/foo", bytes.NewBufferString("Authorization: Basic ZGVtbzpwQDU1dzByZA==")) + req := httptest.NewRequest(http.MethodGet, "/foo", bytes.NewBufferString("Authorization: Basic ZGVtbzpwQDU1dzByZA==")) ctx = redact.Default.WithContext(ctx) req = req.WithContext(ctx) + rt.SetTransport(&transportWithResponse{}) - _, err = rt.RoundTrip(req) + resp, err := rt.RoundTrip(req) assert.NoError(t, err) + defer func() { + err := resp.Body.Close() + assert.NoError(t, err) + }() + assert.Contains(t, out.String(), `"level":"debug"`) assert.Contains(t, out.String(), `"request":"GET /foo HTTP/1.1\r\nHost: example.com\r\n\r\n*************************************ZA=="`) assert.Contains(t, out.String(), `"response":"HTTP/0.0 000 status code 0\r\nContent-Length: 0\r\n\r\n"`) @@ -229,13 +302,19 @@ func TestNewDumpRoundTripperSimple(t *testing.T) { ) require.NoError(t, err) - req := httptest.NewRequest("GET", "/foo", bytes.NewBufferString("Foo")) + req := httptest.NewRequest(http.MethodGet, "/foo", bytes.NewBufferString("Foo")) req = req.WithContext(ctx) + rt.SetTransport(&transportWithResponse{}) - _, err = rt.RoundTrip(req) + resp, err := rt.RoundTrip(req) assert.NoError(t, err) + defer func() { + err := resp.Body.Close() + assert.NoError(t, err) + }() + assert.Contains(t, out.String(), `"level":"debug"`) assert.Contains(t, out.String(), `"request":"GET /foo HTTP/1.1\r\nHost: example.com\r\n\r\n"`) assert.Contains(t, out.String(), `"response":"HTTP/0.0 000 status code 0\r\nContent-Length: 0\r\n\r\n"`) @@ -255,12 +334,17 @@ func TestNewDumpRoundTripperContextOptionsOverwrite(t *testing.T) { out := &bytes.Buffer{} ctx := log.Output(out).WithContext(context.Background()) - req := httptest.NewRequest("GET", "/foo", bytes.NewBufferString("Foo")) + req := httptest.NewRequest(http.MethodGet, "/foo", bytes.NewBufferString("Foo")) req = req.WithContext(ctx) - _, err = rt.RoundTrip(req) + resp, err := rt.RoundTrip(req) require.NoError(t, err) + defer func() { + err := resp.Body.Close() + assert.NoError(t, err) + }() + assert.Contains(t, out.String(), `"level":"debug"`) assert.Contains(t, out.String(), `"request":"GET /foo HTTP/1.1\r\nHost: example.com\r\n\r\n"`) assert.Contains(t, out.String(), `"response":"HTTP/0.0 000 status code 0\r\nContent-Length: 0\r\n\r\n"`) @@ -276,13 +360,20 @@ func TestNewDumpRoundTripperContextOptionsOverwrite(t *testing.T) { WithDumpOption(DumpRoundTripperOptionResponse, false), ) require.NoError(t, err) + ctx = CtxWithDumpRoundTripperOptions(ctx, ctxDumpOptions) - req := httptest.NewRequest("GET", "/foo", bytes.NewBufferString("Foo")) + req := httptest.NewRequest(http.MethodGet, "/foo", bytes.NewBufferString("Foo")) req = req.WithContext(ctx) - _, err = rt.RoundTrip(req) + resp, err := rt.RoundTrip(req) require.NoError(t, err) + + defer func() { + err := resp.Body.Close() + assert.NoError(t, err) + }() + require.Empty(t, out.String()) // Both request and response were disabled for this request }) } @@ -301,12 +392,17 @@ func TestNewDumpRoundTripperContextOptionsOverwriteBody(t *testing.T) { out := &bytes.Buffer{} ctx := log.Output(out).WithContext(context.Background()) - req := httptest.NewRequest("GET", "/foo", bytes.NewBufferString("Foo")) + req := httptest.NewRequest(http.MethodGet, "/foo", bytes.NewBufferString("Foo")) req = req.WithContext(ctx) - _, err = rt.RoundTrip(req) + resp, err := rt.RoundTrip(req) require.NoError(t, err) + defer func() { + err := resp.Body.Close() + assert.NoError(t, err) + }() + assert.Contains(t, out.String(), `"level":"debug"`) assert.Contains(t, out.String(), `"request":"GET /foo HTTP/1.1\r\nHost: example.com\r\n\r\nFoo"`) assert.Contains(t, out.String(), `"response":"HTTP/0.0 000 status code 0\r\n\r\ntest body"`) @@ -321,14 +417,20 @@ func TestNewDumpRoundTripperContextOptionsOverwriteBody(t *testing.T) { WithDumpOption(DumpRoundTripperOptionBody, false), ) require.NoError(t, err) + ctx = CtxWithDumpRoundTripperOptions(ctx, ctxDumpOptions) - req := httptest.NewRequest("GET", "/foo", bytes.NewBufferString("Foo")) + req := httptest.NewRequest(http.MethodGet, "/foo", bytes.NewBufferString("Foo")) req = req.WithContext(ctx) - _, err = rt.RoundTrip(req) + resp, err := rt.RoundTrip(req) require.NoError(t, err) + defer func() { + err := resp.Body.Close() + assert.NoError(t, err) + }() + assert.Contains(t, out.String(), `"level":"debug"`) assert.Contains(t, out.String(), `"request":"GET /foo HTTP/1.1\r\nHost: example.com\r\n\r\n"`) assert.Contains(t, out.String(), `"response":"HTTP/0.0 000 status code 0\r\nContent-Length: 0\r\n\r\n"`) diff --git a/http/transport/external_dependency_round_tripper.go b/http/transport/external_dependency_round_tripper.go index 788d92a17..0a9e36934 100644 --- a/http/transport/external_dependency_round_tripper.go +++ b/http/transport/external_dependency_round_tripper.go @@ -10,7 +10,7 @@ import ( ) // ExternalDependencyRoundTripper greps external dependency headers and -// attach them to the currect context +// attach them to the currect context. type ExternalDependencyRoundTripper struct { name string transport http.RoundTripper @@ -20,17 +20,17 @@ func NewExternalDependencyRoundTripper(name string) *ExternalDependencyRoundTrip return &ExternalDependencyRoundTripper{name: name} } -// Transport returns the RoundTripper to make HTTP requests +// Transport returns the RoundTripper to make HTTP requests. func (l *ExternalDependencyRoundTripper) Transport() http.RoundTripper { return l.transport } -// SetTransport sets the RoundTripper to make HTTP requests +// SetTransport sets the RoundTripper to make HTTP requests. func (l *ExternalDependencyRoundTripper) SetTransport(rt http.RoundTripper) { l.transport = rt } -// RoundTrip executes a single HTTP transaction via Transport() +// RoundTrip executes a single HTTP transaction via Transport(). func (l *ExternalDependencyRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) { start := time.Now() resp, err := l.Transport().RoundTrip(req) diff --git a/http/transport/external_dependency_round_tripper_test.go b/http/transport/external_dependency_round_tripper_test.go index 3fbc5b060..9890510a9 100644 --- a/http/transport/external_dependency_round_tripper_test.go +++ b/http/transport/external_dependency_round_tripper_test.go @@ -8,8 +8,10 @@ import ( "net/http/httptest" "testing" - "github.com/pace/bricks/http/middleware" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/pace/bricks/http/middleware" ) type edRoundTripperMock struct { @@ -24,9 +26,10 @@ func (m *edRoundTripperMock) RoundTrip(req *http.Request) (*http.Response, error func TestExternalDependencyRoundTripper(t *testing.T) { var edc middleware.ExternalDependencyContext + ctx := middleware.ContextWithExternalDependency(context.Background(), &edc) - r := httptest.NewRequest("GET", "http://example.com/test", nil) + r := httptest.NewRequest(http.MethodGet, "http://example.com/test", nil) r = r.WithContext(ctx) mock := &edRoundTripperMock{ @@ -38,17 +41,18 @@ func TestExternalDependencyRoundTripper(t *testing.T) { } lrt := &ExternalDependencyRoundTripper{transport: mock} - _, err := lrt.RoundTrip(r) - assert.NoError(t, err) + _, err := lrt.RoundTrip(r) //nolint:bodyclose + require.NoError(t, err) assert.EqualValues(t, "test1:123,test2:53", edc.String()) } func TestExternalDependencyRoundTripperWithName(t *testing.T) { var edc middleware.ExternalDependencyContext + ctx := middleware.ContextWithExternalDependency(context.Background(), &edc) - r := httptest.NewRequest("GET", "http://example.com/test", nil) + r := httptest.NewRequest(http.MethodGet, "http://example.com/test", nil) r = r.WithContext(ctx) mock := &edRoundTripperMock{ @@ -60,8 +64,8 @@ func TestExternalDependencyRoundTripperWithName(t *testing.T) { } lrt := &ExternalDependencyRoundTripper{name: "ext", transport: mock} - _, err := lrt.RoundTrip(r) - assert.NoError(t, err) + _, err := lrt.RoundTrip(r) //nolint:bodyclose + require.NoError(t, err) assert.EqualValues(t, "ext:0,test1:123,test2:53", edc.String()) } diff --git a/http/transport/locale_round_tripper.go b/http/transport/locale_round_tripper.go index b742617dd..55523ced4 100644 --- a/http/transport/locale_round_tripper.go +++ b/http/transport/locale_round_tripper.go @@ -8,22 +8,22 @@ import ( "github.com/pace/bricks/locale" ) -// LocaleRoundTripper implements a chainable round tripper for locale forwarding +// LocaleRoundTripper implements a chainable round tripper for locale forwarding. type LocaleRoundTripper struct { transport http.RoundTripper } -// Transport returns the RoundTripper to make HTTP requests +// Transport returns the RoundTripper to make HTTP requests. func (l *LocaleRoundTripper) Transport() http.RoundTripper { return l.transport } -// SetTransport sets the RoundTripper to make HTTP requests +// SetTransport sets the RoundTripper to make HTTP requests. func (l *LocaleRoundTripper) SetTransport(rt http.RoundTripper) { l.transport = rt } -// RoundTrip executes a HTTP request with logging +// RoundTrip executes a HTTP request with logging. func (l *LocaleRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) { loc, ok := locale.FromCtx(req.Context()) if ok { diff --git a/http/transport/locale_round_tripper_test.go b/http/transport/locale_round_tripper_test.go index cc27312d4..fe1e25c06 100644 --- a/http/transport/locale_round_tripper_test.go +++ b/http/transport/locale_round_tripper_test.go @@ -8,10 +8,10 @@ import ( "net/http/httputil" "testing" - "github.com/pace/bricks/locale" - "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + + "github.com/pace/bricks/locale" ) type roundTripperMock struct { @@ -28,10 +28,11 @@ func TestLocaleRoundTrip(t *testing.T) { lrt := &LocaleRoundTripper{transport: mock} l := locale.NewLocale("fr-CH, fr;q=0.9, en;q=0.8, de;q=0.7, *;q=0.5", "Europe/Paris") - r, err := http.NewRequest("GET", "http://example.com/test", nil) + r, err := http.NewRequest(http.MethodGet, "http://example.com/test", nil) require.NoError(t, err) - lrt.RoundTrip(r.WithContext(locale.WithLocale(context.Background(), l))) // nolint: errcheck + _, err = lrt.RoundTrip(r.WithContext(locale.WithLocale(context.Background(), l))) //nolint:bodyclose + require.NoError(t, err) lctx, ok := locale.FromCtx(mock.r.Context()) require.True(t, ok) diff --git a/http/transport/logging_round_tripper.go b/http/transport/logging_round_tripper.go index b6465f236..7c2b05397 100644 --- a/http/transport/logging_round_tripper.go +++ b/http/transport/logging_round_tripper.go @@ -11,36 +11,37 @@ import ( "github.com/pace/bricks/maintenance/log" ) -// LoggingRoundTripper implements a chainable round tripper for logging +// LoggingRoundTripper implements a chainable round tripper for logging. type LoggingRoundTripper struct { transport http.RoundTripper } -// Transport returns the RoundTripper to make HTTP requests +// Transport returns the RoundTripper to make HTTP requests. func (l *LoggingRoundTripper) Transport() http.RoundTripper { return l.transport } -// SetTransport sets the RoundTripper to make HTTP requests +// SetTransport sets the RoundTripper to make HTTP requests. func (l *LoggingRoundTripper) SetTransport(rt http.RoundTripper) { l.transport = rt } -// RoundTrip executes a HTTP request with logging +// RoundTrip executes a HTTP request with logging. func (l *LoggingRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) { ctx := req.Context() startTime := time.Now() - le := log.Ctx(ctx).Debug(). - Str("url", req.URL.String()). - Str("method", req.Method). - Str("sentry:type", "http"). - Str("sentry:category", "http") + le := log.Ctx(ctx).Debug(). //nolint:zerologlint + Str("url", req.URL.String()). + Str("method", req.Method). + Str("sentry:type", "http"). + Str("sentry:category", "http") resp, err := l.Transport().RoundTrip(req) dur := float64(time.Since(startTime)) / float64(time.Millisecond) le = le.Float64("duration", dur) attempt := attemptFromCtx(ctx) + if attempt > 0 { le = le.Int("attempt", int(attempt)) } diff --git a/http/transport/logging_round_tripper_test.go b/http/transport/logging_round_tripper_test.go index 02bf4555b..372d6d129 100644 --- a/http/transport/logging_round_tripper_test.go +++ b/http/transport/logging_round_tripper_test.go @@ -5,11 +5,14 @@ package transport import ( "bytes" "context" + "net/http" "net/http/httptest" "net/url" "strings" "testing" + "github.com/stretchr/testify/assert" + "github.com/pace/bricks/maintenance/log" ) @@ -19,26 +22,34 @@ func TestLoggingRoundTripper(t *testing.T) { ctx := log.Output(out).WithContext(context.Background()) // create request with context and url - req := httptest.NewRequest("GET", "/foo", nil).WithContext(ctx) + req := httptest.NewRequest(http.MethodGet, "/foo", nil).WithContext(ctx) + url, err := url.Parse("http://example.com/foo") if err != nil { panic(err) } + req.URL = url t.Run("Without retries", func(t *testing.T) { l := &LoggingRoundTripper{} - l.SetTransport(&transportWithResponse{statusCode: 200}) + l.SetTransport(&transportWithResponse{statusCode: http.StatusOK}) - _, err = l.RoundTrip(req) + resp, err := l.RoundTrip(req) if err != nil { t.Fatalf("Expected err to be nil, got %#v", err) } + defer func() { + err := resp.Body.Close() + assert.NoError(t, err) + }() + got := out.String() if !strings.Contains(got, "duration") { t.Errorf("Expected duration to be contained in log output, got %v", got) } + if strings.Contains(got, "retries") { t.Errorf("Expected retries to not be contained in log output, got %v", got) } @@ -54,13 +65,19 @@ func TestLoggingRoundTripper(t *testing.T) { l := Chain(NewDefaultRetryRoundTripper(), &LoggingRoundTripper{}) l.Final(&retriedTransport{statusCodes: []int{502, 503, 408, 202}}) - _, err = l.RoundTrip(req) + resp, err := l.RoundTrip(req) if err != nil { t.Fatalf("Expected err to be nil, got %#v", err) } + defer func() { + err := resp.Body.Close() + assert.NoError(t, err) + }() + got := out.String() exs := []string{`"level":"debug"`, `"url":"http://example.com/foo"`, `"method":"GET"`, `"status_code":200`, `"message":"HTTP GET example.com"`, `"attempt":3`} + for _, ex := range exs { if !strings.Contains(got, ex) { t.Errorf("Expected %v to be contained in log output, got %v", ex, got) diff --git a/http/transport/request_id.go b/http/transport/request_id.go index e4945cc19..3aae8c4a9 100644 --- a/http/transport/request_id.go +++ b/http/transport/request_id.go @@ -8,27 +8,28 @@ import ( "github.com/pace/bricks/maintenance/log" ) -// RequestIDRoundTripper implements a chainable round tripper for setting the Request-Source header +// RequestIDRoundTripper implements a chainable round tripper for setting the Request-Source header. type RequestIDRoundTripper struct { transport http.RoundTripper SourceName string } -// Transport returns the RoundTripper to make HTTP requests +// Transport returns the RoundTripper to make HTTP requests. func (l *RequestIDRoundTripper) Transport() http.RoundTripper { return l.transport } -// SetTransport sets the RoundTripper to make HTTP requests +// SetTransport sets the RoundTripper to make HTTP requests. func (l *RequestIDRoundTripper) SetTransport(rt http.RoundTripper) { l.transport = rt } -// RoundTrip executes a single HTTP transaction via Transport() +// RoundTrip executes a single HTTP transaction via Transport(). func (l *RequestIDRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) { ctx := req.Context() if reqID := log.RequestIDFromContext(ctx); reqID != "" { req.Header.Set("Request-Id", reqID) } + return l.Transport().RoundTrip(req) } diff --git a/http/transport/request_id_test.go b/http/transport/request_id_test.go index e5f6ff43a..03df3a927 100644 --- a/http/transport/request_id_test.go +++ b/http/transport/request_id_test.go @@ -8,9 +8,10 @@ import ( "testing" "github.com/gorilla/mux" - "github.com/pace/bricks/maintenance/log" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + + "github.com/pace/bricks/maintenance/log" ) func TestRequestIDRoundTripper(t *testing.T) { @@ -18,9 +19,16 @@ func TestRequestIDRoundTripper(t *testing.T) { rt.SetTransport(&transportWithResponse{}) t.Run("without req_id", func(t *testing.T) { - req := httptest.NewRequest("GET", "/foo", nil) - _, err := rt.RoundTrip(req) - assert.NoError(t, err) + req := httptest.NewRequest(http.MethodGet, "/foo", nil) + + resp, err := rt.RoundTrip(req) + require.NoError(t, err) + + defer func() { + err := resp.Body.Close() + assert.NoError(t, err) + }() + assert.Empty(t, req.Header["Request-Id"]) }) @@ -34,17 +42,23 @@ func TestRequestIDRoundTripper(t *testing.T) { require.Equal(t, ID, log.RequestID(r)) require.Equal(t, ID, log.RequestIDFromContext(r.Context())) - r1 := httptest.NewRequest("GET", "/foo", nil) + r1 := httptest.NewRequest(http.MethodGet, "/foo", nil) r1 = r1.WithContext(r.Context()) - _, err := rt.RoundTrip(r1) - assert.NoError(t, err) + resp, err := rt.RoundTrip(r1) + require.NoError(t, err) + + defer func() { + err := resp.Body.Close() + assert.NoError(t, err) + }() + assert.Equal(t, []string{ID}, r1.Header["Request-Id"]) w.WriteHeader(http.StatusNoContent) }) rec := httptest.NewRecorder() - req := httptest.NewRequest("GET", "/foo", nil) + req := httptest.NewRequest(http.MethodGet, "/foo", nil) req.Header.Set("Request-Id", ID) r.ServeHTTP(rec, req) assert.Equal(t, http.StatusNoContent, rec.Code) diff --git a/http/transport/request_source_round_tripper.go b/http/transport/request_source_round_tripper.go index 30b400f3b..da81a2f59 100644 --- a/http/transport/request_source_round_tripper.go +++ b/http/transport/request_source_round_tripper.go @@ -6,23 +6,23 @@ import ( "net/http" ) -// RequestSourceRoundTripper implements a chainable round tripper for setting the Request-Source header +// RequestSourceRoundTripper implements a chainable round tripper for setting the Request-Source header. type RequestSourceRoundTripper struct { transport http.RoundTripper SourceName string } -// Transport returns the RoundTripper to make HTTP requests +// Transport returns the RoundTripper to make HTTP requests. func (l *RequestSourceRoundTripper) Transport() http.RoundTripper { return l.transport } -// SetTransport sets the RoundTripper to make HTTP requests +// SetTransport sets the RoundTripper to make HTTP requests. func (l *RequestSourceRoundTripper) SetTransport(rt http.RoundTripper) { l.transport = rt } -// RoundTrip executes a single HTTP transaction via Transport() +// RoundTrip executes a single HTTP transaction via Transport(). func (l *RequestSourceRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) { req.Header.Set("Request-Source", l.SourceName) return l.Transport().RoundTrip(req) diff --git a/http/transport/request_source_round_tripper_test.go b/http/transport/request_source_round_tripper_test.go index 3305ee1d2..12cb260f4 100644 --- a/http/transport/request_source_round_tripper_test.go +++ b/http/transport/request_source_round_tripper_test.go @@ -3,19 +3,27 @@ package transport import ( + "net/http" "net/http/httptest" "testing" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) func TestRequestSourceRoundTripper(t *testing.T) { - req := httptest.NewRequest("GET", "/foo", nil) + req := httptest.NewRequest(http.MethodGet, "/foo", nil) rt := RequestSourceRoundTripper{SourceName: "foobar"} rt.SetTransport(&transportWithResponse{}) - _, err := rt.RoundTrip(req) - assert.NoError(t, err) + resp, err := rt.RoundTrip(req) + require.NoError(t, err) + + defer func() { + err := resp.Body.Close() + assert.NoError(t, err) + }() + assert.Equal(t, []string{"foobar"}, req.Header["Request-Source"]) } diff --git a/http/transport/retry_round_tripper.go b/http/transport/retry_round_tripper.go index ffa30fe44..16f5bc871 100644 --- a/http/transport/retry_round_tripper.go +++ b/http/transport/retry_round_tripper.go @@ -15,7 +15,7 @@ import ( const maxRetries = 9 -// RetryRoundTripper implements a chainable round tripper for retrying requests +// RetryRoundTripper implements a chainable round tripper for retrying requests. type RetryRoundTripper struct { retryTransport *rehttp.Transport transport http.RoundTripper @@ -24,17 +24,14 @@ type RetryRoundTripper struct { // RetryNetErr retries errors returned by the 'net' package. func RetryNetErr() rehttp.RetryFn { return func(attempt rehttp.Attempt) bool { - if _, isNetError := attempt.Error.(*net.OpError); isNetError { - return true - } - return false + return errors.Is(attempt.Error, &net.OpError{}) } } -// RetryEOFErr retries only when the error is EOF +// RetryEOFErr retries only when the error is EOF. func RetryEOFErr() rehttp.RetryFn { return func(attempt rehttp.Attempt) bool { - return attempt.Error == io.EOF + return errors.Is(attempt.Error, io.EOF) } } @@ -81,23 +78,24 @@ func (rt *retryWrappedTransport) RoundTrip(r *http.Request) (*http.Response, err return rt.transport.RoundTrip(r) } -// Transport returns the RoundTripper to make HTTP requests +// Transport returns the RoundTripper to make HTTP requests. func (l *RetryRoundTripper) Transport() http.RoundTripper { return l.transport } -// SetTransport sets the RoundTripper to make HTTP requests +// SetTransport sets the RoundTripper to make HTTP requests. func (l *RetryRoundTripper) SetTransport(rt http.RoundTripper) { l.transport = rt } -// RoundTrip executes a HTTP request with retrying +// RoundTrip executes a HTTP request with retrying. func (l *RetryRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) { retryTransport := *l.retryTransport wrappedTransport := &retryWrappedTransport{ transport: transportWithAttempt(l.Transport()), } retryTransport.RoundTripper = wrappedTransport + resp, err := retryTransport.RoundTrip(req) if err != nil { return nil, err diff --git a/http/transport/retry_round_tripper_test.go b/http/transport/retry_round_tripper_test.go index d1a26494c..078d786db 100644 --- a/http/transport/retry_round_tripper_test.go +++ b/http/transport/retry_round_tripper_test.go @@ -12,6 +12,7 @@ import ( "net/http/httptest" "testing" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -34,7 +35,7 @@ func TestRetryRoundTripper(t *testing.T) { name: "Successful response after some retries", args: args{ requestBody: []byte(`{"key":"value""}`), - statuses: []int{408, 502, 503, 504, 200}, + statuses: []int{408, 502, 503, 504, http.StatusOK}, }, wantRetries: 5, }, @@ -69,7 +70,7 @@ func TestRetryRoundTripper(t *testing.T) { name: "Exceed retries", args: args{ requestBody: []byte(`{"key":"value""}`), - statuses: []int{408, 502, 503, 504, 504, 504, 504, 504, 504, 504, 504, 504, 504, 504, 504, 504, 504, 504, 504, 504, 200}, + statuses: []int{408, 502, 503, 504, 504, 504, 504, 504, 504, 504, 504, 504, 504, 504, 504, 504, 504, 504, 504, 504, http.StatusOK}, }, wantRetries: 10, wantErr: ErrRetryFailed, @@ -84,7 +85,8 @@ func TestRetryRoundTripper(t *testing.T) { } rt.SetTransport(tr) - req := httptest.NewRequest("GET", "/foo", bytes.NewReader(tt.args.requestBody)) + req := httptest.NewRequest(http.MethodGet, "/foo", bytes.NewReader(tt.args.requestBody)) + resp, err := rt.RoundTrip(req.WithContext(context.Background())) require.Equal(t, tt.wantRetries, tr.attempts) @@ -93,8 +95,14 @@ func TestRetryRoundTripper(t *testing.T) { require.ErrorIs(t, err, tt.wantErr) return } + require.NoError(t, err) + defer func() { + err := resp.Body.Close() + assert.NoError(t, err) + }() + body, err := io.ReadAll(resp.Body) require.NoError(t, err) require.Equal(t, string(tt.args.requestBody), string(body)) @@ -120,6 +128,7 @@ func (t *retriedTransport) RoundTrip(req *http.Request) (*http.Response, error) if t.err != nil { return nil, fmt.Errorf("%w", t.err) } + readAll, _ := io.ReadAll(req.Body) body := io.NopCloser(bytes.NewReader(readAll)) resp := &http.Response{Body: body, StatusCode: t.statusCodes[t.attempts]} diff --git a/internal/service/generate/cmds.go b/internal/service/generate/cmds.go index a0706eec8..720dbf7d8 100644 --- a/internal/service/generate/cmds.go +++ b/internal/service/generate/cmds.go @@ -15,13 +15,13 @@ import ( const errorsPkg = "github.com/pace/bricks/maintenance/errors" // CommandOptions are applied when generating the different -// microservice commands +// microservice commands. type CommandOptions struct { DaemonName string ControlName string } -// NewCommandOptions generate command names using given name +// NewCommandOptions generate command names using given name. func NewCommandOptions(name string) CommandOptions { return CommandOptions{ DaemonName: name + "d", @@ -30,7 +30,7 @@ func NewCommandOptions(name string) CommandOptions { } // Commands generates the microservice commands based of -// the given path +// the given path. func Commands(path string, options CommandOptions) { // Create directories dirs := []string{ @@ -38,15 +38,14 @@ func Commands(path string, options CommandOptions) { filepath.Join(path, "cmd", options.ControlName), } for _, dir := range dirs { - err := os.MkdirAll(dir, 0o770) // nolint: gosec - if err != nil { + if err := os.MkdirAll(dir, 0o750); err != nil { log.Fatal(fmt.Printf("Failed to create dir %s: %v", dir, err)) } } // Create commands files for _, dir := range dirs { - f, err := os.Create(filepath.Join(dir, "main.go")) + f, err := os.Create(filepath.Join(dir, "main.go")) //nolint:gosec if err != nil { log.Fatal(err) } @@ -59,6 +58,7 @@ func Commands(path string, options CommandOptions) { } else { generateControlMain(code, cmdName) } + _, err = f.WriteString(copyright()) if err != nil { log.Fatal(err) @@ -98,10 +98,11 @@ func generateControlMain(f *jen.File, cmdName string) { jen.Qual("fmt", "Printf").Call(jen.Lit(cmdName))) } -// copyright generates copyright statement +// copyright generates copyright statement. func copyright() string { stmt := "" now := time.Now() stmt += fmt.Sprintf("// Copyright © %04d by PACE Telematics GmbH. All rights reserved.\n", now.Year()) + return stmt } diff --git a/internal/service/generate/dockerfile.go b/internal/service/generate/dockerfile.go index 9418589ba..c881badea 100644 --- a/internal/service/generate/dockerfile.go +++ b/internal/service/generate/dockerfile.go @@ -9,22 +9,21 @@ import ( ) // DockerfileOptions configure the output of the generated docker -// file +// file. type DockerfileOptions struct { Name string Commands CommandOptions } // Dockerfile generate a dockerfile using the given options -// for specified path +// for specified path. func Dockerfile(path string, options DockerfileOptions) { - f, err := os.Create(path) + f, err := os.Create(path) //nolint:gosec if err != nil { log.Fatal(err) } - err = dockerTemplate.Execute(f, options) - if err != nil { + if err := dockerTemplate.Execute(f, options); err != nil { log.Fatal(err) } } diff --git a/internal/service/generate/error.go b/internal/service/generate/error.go index 7411b22ac..d1bcb317a 100644 --- a/internal/service/generate/error.go +++ b/internal/service/generate/error.go @@ -9,15 +9,16 @@ import ( "github.com/pace/bricks/internal/service/generate/errordefinition/generator" ) -// ErrorDefinitionFileOptions options that change the rendering of the error definition file +// ErrorDefinitionFileOptions options that change the rendering of the error definition file. type ErrorDefinitionFileOptions struct { PkgName, Path, Source string } -// ErrorDefinitionFile builds a file with error definitions +// ErrorDefinitionFile builds a file with error definitions. func ErrorDefinitionFile(options ErrorDefinitionFileOptions) { // generate error definition g := generator.Generator{} + result, err := g.BuildSource(options.Source, options.Path, options.PkgName) if err != nil { log.Fatal(err) @@ -28,6 +29,7 @@ func ErrorDefinitionFile(options ErrorDefinitionFileOptions) { func ErrorDefinitionsMarkdown(options ErrorDefinitionFileOptions) { g := generator.Generator{} + result, err := g.BuildMarkdown(options.Source) if err != nil { log.Fatal(err) @@ -38,11 +40,16 @@ func ErrorDefinitionsMarkdown(options ErrorDefinitionFileOptions) { func writeResult(result, path string) { // create file - file, err := os.Create(path) + file, err := os.Create(path) //nolint:gosec if err != nil { log.Fatal(err) } - defer file.Close() // nolint: errcheck + + defer func() { + if err := file.Close(); err != nil { + log.Printf("failed closing file body: %v\n", err) + } + }() // write file _, err = file.WriteString(result) diff --git a/internal/service/generate/errordefinition/generator/generate.go b/internal/service/generate/errordefinition/generator/generate.go index 5fb79ac44..03ea8f750 100644 --- a/internal/service/generate/errordefinition/generator/generate.go +++ b/internal/service/generate/errordefinition/generator/generate.go @@ -27,6 +27,7 @@ type Generator struct { func loadDefinitionData(source string) ([]byte, error) { var data []byte + if strings.HasPrefix(source, "http://") || strings.HasPrefix(source, "https://") { loc, err := url.Parse(source) if err != nil { @@ -40,7 +41,8 @@ func loadDefinitionData(source string) ([]byte, error) { } else { // read definition file from disk var err error - data, err = os.ReadFile(source) // nolint: gosec + + data, err = os.ReadFile(source) //nolint:gosec if err != nil { return nil, err } @@ -54,17 +56,23 @@ func loadDefinitionDataFromURI(url *url.URL) ([]byte, error) { if err != nil { return nil, err } - defer resp.Body.Close() // nolint: errcheck + + defer func() { + if err := resp.Body.Close(); err != nil { + fmt.Fprintf(os.Stderr, "failed closing response body: %v", err) + } + }() body, err := io.ReadAll(resp.Body) if err != nil { return nil, err } + return body, nil } // BuildSource generates the go code in the specified path with specified package name -// based on the passed schema source (url or file path) +// based on the passed schema source (url or file path). func (g *Generator) BuildSource(source, packagePath, packageName string) (string, error) { data, err := loadDefinitionData(source) if err != nil { @@ -73,8 +81,8 @@ func (g *Generator) BuildSource(source, packagePath, packageName string) (string // parse definition var errors runtime.Errors - err = json.Unmarshal(data, &errors) - if err != nil { + + if err := json.Unmarshal(data, &errors); err != nil { return "", err } @@ -89,16 +97,15 @@ func (g *Generator) BuildDefinitions(errors runtime.Errors, packagePath, package // create a error code const for easier runtime error comparison - var constObjects []jen.Code - for _, jsonError := range errors { + constObjects := make([]jen.Code, 0) + for _, jsonError := range errors { // skip example if given if jsonError.Code == "EXAMPLE" { continue } constObjects = append(constObjects, jen.Id(fmt.Sprintf("ERR_CODE_%s", jsonError.Code)).Op("=").Lit(jsonError.Code)) - } if len(constObjects) > 0 { @@ -106,7 +113,6 @@ func (g *Generator) BuildDefinitions(errors runtime.Errors, packagePath, package } for _, jsonError := range errors { - // skip example if given if jsonError.Code == "EXAMPLE" { continue diff --git a/internal/service/generate/errordefinition/generator/markdown.go b/internal/service/generate/errordefinition/generator/markdown.go index 56f5af068..b4dcebd20 100644 --- a/internal/service/generate/errordefinition/generator/markdown.go +++ b/internal/service/generate/errordefinition/generator/markdown.go @@ -37,8 +37,8 @@ func (g *Generator) BuildMarkdown(source string) (string, error) { func (g *Generator) parseDefinitions(data []byte) (ErrorDefinitions, error) { var parsedData []ErrorDefinition - err := json.Unmarshal(data, &parsedData) - if err != nil { + + if err := json.Unmarshal(data, &parsedData); err != nil { return nil, err } @@ -70,12 +70,14 @@ func (g *Generator) generateMarkdown(eds ErrorDefinitions) (string, error) { if err != nil { return "", err } + _, err = output.WriteString(`|Code|Title| |-----------|-----------| `) if err != nil { panic(err) } + for _, detail := range details { _, err := output.WriteString(fmt.Sprintf("|%s|%s|\n", detail.Code, detail.Title)) if err != nil { diff --git a/internal/service/generate/makefile.go b/internal/service/generate/makefile.go index cec4de7bc..f5ca3a07e 100644 --- a/internal/service/generate/makefile.go +++ b/internal/service/generate/makefile.go @@ -9,21 +9,20 @@ import ( ) // MakefileOptions options that change the rendering -// of the makefile +// of the makefile. type MakefileOptions struct { Name string } // Makefile generates a with given options for the -// specified path +// specified path. func Makefile(path string, options MakefileOptions) { - f, err := os.Create(path) + f, err := os.Create(path) //nolint:gosec if err != nil { log.Fatal(err) } - err = makefileTemplate.Execute(f, options) - if err != nil { + if err := makefileTemplate.Execute(f, options); err != nil { log.Fatal(err) } } diff --git a/internal/service/generate/rest.go b/internal/service/generate/rest.go index c06857bd3..e1269073f 100644 --- a/internal/service/generate/rest.go +++ b/internal/service/generate/rest.go @@ -9,15 +9,16 @@ import ( "github.com/pace/bricks/http/jsonapi/generator" ) -// RestOptions options to respect when generating the rest api +// RestOptions options to respect when generating the rest api. type RestOptions struct { PkgName, Path, Source string } -// Rest builds a jsonapi rest api +// Rest builds a jsonapi rest api. func Rest(options RestOptions) { // generate jsonapi g := generator.Generator{} + result, err := g.BuildSource(options.Source, options.Path, options.PkgName) if err != nil { log.Fatal(err) @@ -28,7 +29,12 @@ func Rest(options RestOptions) { if err != nil { log.Fatal(err) } - defer file.Close() // nolint: errcheck + + defer func() { + if err := file.Close(); err != nil { + log.Printf("failed closing file body: %v\n", err) + } + }() // write file _, err = file.WriteString(result) diff --git a/internal/service/helper.go b/internal/service/helper.go index 5d99df655..dafed494a 100644 --- a/internal/service/helper.go +++ b/internal/service/helper.go @@ -12,17 +12,17 @@ import ( "path/filepath" ) -// PaceBase for all go projects +// PaceBase for all go projects. const PaceBase = "git.pace.cloud/pace" -// ServiceBase for all go microservices +// ServiceBase for all go microservices. const ServiceBase = "web/service" -// GitLabTemplate git clone template for cloning repositories +// GitLabTemplate git clone template for cloning repositories. const GitLabTemplate = "git@git.pace.cloud:pace/web/service/%s.git" // GoPath returns the gopath for the current system, -// uses GOPATH env and fallback to default go dir +// uses GOPATH env and fallback to default go dir. func GoPath() string { path, ok := os.LookupEnv("GOPATH") if !ok { @@ -30,6 +30,7 @@ func GoPath() string { if err != nil { log.Fatal(err) } + return filepath.Join(usr.HomeDir, "go") } @@ -37,7 +38,7 @@ func GoPath() string { } // PacePath returns the pace path for the current system, -// uses PACE_PATH env and fallback to default go dir +// uses PACE_PATH env and fallback to default go dir. func PacePath() string { path, ok := os.LookupEnv("PACE_PATH") if !ok { @@ -45,26 +46,27 @@ func PacePath() string { if err != nil { log.Fatal(err) } + return filepath.Join(usr.HomeDir, "PACE") } return path } -// GoServicePath returns the path of the go service for given name +// GoServicePath returns the path of the go service for given name. func GoServicePath(name string) (string, error) { return filepath.Abs(filepath.Join(PacePath(), ServiceBase, name)) } -// GoServicePackagePath returns a go package path for given service name +// GoServicePackagePath returns a go package path for given service name. func GoServicePackagePath(name string) string { return filepath.Join(PaceBase, ServiceBase, name) } -// AutoInstall cmdName if not installed already using go get -u goGetPath +// AutoInstall cmdName if not installed already using go get -u goGetPath. func AutoInstall(cmdName, goGetPath string) { if _, err := os.Stat(GoBinCommand(cmdName)); os.IsNotExist(err) { - fmt.Fprintf(os.Stderr, "Installing %s using: go get -u %s\n", cmdName, goGetPath) // nolint: errcheck + fmt.Fprintf(os.Stderr, "Installing %s using: go get -u %s\n", cmdName, goGetPath) // assume error means no file SimpleExec("go", "get", "-u", goGetPath) } else if err != nil { @@ -72,44 +74,44 @@ func AutoInstall(cmdName, goGetPath string) { } } -// GoBinCommand returns the path to a binary installed in the gopath +// GoBinCommand returns the path to a binary installed in the gopath. func GoBinCommand(cmdName string) string { return filepath.Join(GoPath(), "bin", cmdName) } -// SimpleExec executes the command and uses the parent process STDIN,STDOUT,STDERR +// SimpleExec executes the command and uses the parent process STDIN,STDOUT,STDERR. func SimpleExec(cmdName string, arguments ...string) { - cmd := exec.Command(cmdName, arguments...) // nolint: gosec + cmd := exec.Command(cmdName, arguments...) cmd.Stdin = os.Stdin cmd.Stdout = os.Stdout cmd.Stderr = os.Stderr - err := cmd.Run() - if err != nil { + + if err := cmd.Run(); err != nil { log.Fatal(err) } } -// SimpleExecInPath executes the command and uses the parent process STDIN,STDOUT,STDERR in passed dir +// SimpleExecInPath executes the command and uses the parent process STDIN,STDOUT,STDERR in passed dir. func SimpleExecInPath(dir, cmdName string, arguments ...string) { - cmd := exec.Command(cmdName, arguments...) // nolint: gosec + cmd := exec.Command(cmdName, arguments...) cmd.Stdin = os.Stdin cmd.Stdout = os.Stdout cmd.Stderr = os.Stderr cmd.Dir = dir - err := cmd.Run() - if err != nil { + + if err := cmd.Run(); err != nil { log.Fatal(err) } } -// GoBinCommandText writes the command output to the passed writer +// GoBinCommandText writes the command output to the passed writer. func GoBinCommandText(w io.Writer, cmdName string, arguments ...string) { - cmd := exec.Command(cmdName, arguments...) // nolint: gosec + cmd := exec.Command(cmdName, arguments...) cmd.Stdin = os.Stdin cmd.Stdout = w cmd.Stderr = os.Stderr - err := cmd.Run() - if err != nil { + + if err := cmd.Run(); err != nil { log.Fatal(err) } } diff --git a/internal/service/new.go b/internal/service/new.go index 51fba793a..3c480649e 100644 --- a/internal/service/new.go +++ b/internal/service/new.go @@ -12,12 +12,12 @@ import ( ) // NewOptions collection of options to apply while or -// after the creation of the new project +// after the creation of the new project. type NewOptions struct { RestSource string // url or path to OpenAPIv3 (json:api) specification } -// New creates a new directory in the go path +// New creates a new directory in the go path. func New(name string, options NewOptions) { // get dir for the service dir, err := GoServicePath(name) @@ -32,8 +32,8 @@ func New(name string, options NewOptions) { // add REST API if there was a source specified if options.RestSource != "" { restDir := filepath.Join(dir, "internal", "http", "rest") - err := os.MkdirAll(restDir, 0o770) // nolint: gosec - if err != nil { + + if err := os.MkdirAll(restDir, 0o750); err != nil { log.Fatal(fmt.Printf("Failed to generate dir for rest api %s: %v", restDir, err)) } diff --git a/internal/service/service.go b/internal/service/service.go index 7c3d65b68..b06d67130 100644 --- a/internal/service/service.go +++ b/internal/service/service.go @@ -11,7 +11,7 @@ import ( "strings" ) -// Clone the service into pace path +// Clone the service into pace path. func Clone(name string) { // get dir for the service dir, err := GoServicePath(name) @@ -22,24 +22,23 @@ func Clone(name string) { SimpleExec("git", "clone", fmt.Sprintf(GitLabTemplate, name), dir) } -// Path prints the path of the service identified by name to STDOUT +// Path prints the path of the service identified by name to STDOUT. func Path(name string) { // get dir for the service dir, err := GoServicePath(name) if err != nil { log.Fatal(err) } + fmt.Println(dir) } // Edit the service with given name in favorite editor, defined -// by env PACE_EDITOR or EDITOR +// by env PACE_EDITOR or EDITOR. func Edit(name string) { editor, ok := os.LookupEnv("PACE_EDITOR") - if !ok { editor, ok = os.LookupEnv("EDITOR") - if !ok { log.Fatal("No $PACE_EDITOR or $EDITOR defined!") } @@ -54,14 +53,14 @@ func Edit(name string) { SimpleExec(editor, dir) } -// RunOptions fallback cmdName and additional arguments for the run cmd +// RunOptions fallback cmdName and additional arguments for the run cmd. type RunOptions struct { CmdName string // alternative name for the command of the service Args []string // rest of arguments } // Run the service daemon for the given name or use the optional -// cmdname instead +// cmdname instead. func Run(name string, options RunOptions) { // get dir for the service dir, err := GoServicePath(name) @@ -76,6 +75,7 @@ func Run(name string, options RunOptions) { } else { args, err = filepath.Glob(filepath.Join(dir, fmt.Sprintf("cmd/%s/*.go", options.CmdName))) } + if err != nil { log.Fatal(err) } @@ -86,12 +86,12 @@ func Run(name string, options RunOptions) { SimpleExec("go", args...) } -// TestOptions options to respect when starting a test +// TestOptions options to respect when starting a test. type TestOptions struct { GoConvey bool } -// Test execute the gorich or goconvey test runners +// Test execute the gorich or goconvey test runners. func Test(name string, options TestOptions) { if options.GoConvey { AutoInstall("goconvey", "github.com/smartystreets/goconvey") @@ -115,12 +115,14 @@ func Test(name string, options TestOptions) { } } -// Lint executes golint or installes if not already installed +// Lint executes golint or installs if not already installed. func Lint(name string) { AutoInstall("golint", "golang.org/x/lint/golint") var buf bytes.Buffer + GoBinCommandText(&buf, "go", "list", filepath.Join(GoServicePackagePath(name), "...")) + paths := strings.Split(buf.String(), "\n") // start go run diff --git a/locale/cfg.go b/locale/cfg.go index 399d80f77..3dfbef238 100644 --- a/locale/cfg.go +++ b/locale/cfg.go @@ -16,8 +16,7 @@ type config struct { var cfg config func init() { - err := env.Parse(&cfg) - if err != nil { + if err := env.Parse(&cfg); err != nil { log.Fatalf("Failed to parse environment: %v", err) } } diff --git a/locale/context.go b/locale/context.go index 07c69efa2..db9f3a51a 100644 --- a/locale/context.go +++ b/locale/context.go @@ -6,13 +6,13 @@ import ( "context" ) -// ctx private key type to seal the access +// ctx private key type to seal the access. type ctx string -// tokenKey private key to seal the access +// tokenKey private key to seal the access. var tokenKey = ctx("locale") -// WithLocale creates a new context with the passed locale +// WithLocale creates a new context with the passed locale. func WithLocale(ctx context.Context, locale *Locale) context.Context { return context.WithValue(ctx, tokenKey, locale) } @@ -24,12 +24,14 @@ func FromCtx(ctx context.Context) (*Locale, bool) { if val == nil { return new(Locale), false } + l, ok := val.(*Locale) + return l, ok } // ContextTransfer sources the locale from the sourceCtx -// and returns a new context based on the targetCtx +// and returns a new context based on the targetCtx. func ContextTransfer(sourceCtx context.Context, targetCtx context.Context) context.Context { l, _ := FromCtx(sourceCtx) return WithLocale(targetCtx, l) diff --git a/locale/http.go b/locale/http.go index 3112e8de7..c04838e2d 100644 --- a/locale/http.go +++ b/locale/http.go @@ -20,26 +20,28 @@ func (l Locale) Request(r *http.Request) *http.Request { if l.HasLanguage() { r.Header.Set(HeaderAcceptLanguage, l.acceptLanguage) } + if l.HasTimezone() { r.Header.Set(HeaderAcceptTimezone, l.acceptTimezone) } + return r } // Middleware takes the accept lang and timezone info and -// stores them in the context +// stores them in the context. type Middleware struct { next http.Handler } -// ServeHTTP adds the locale to the request context +// ServeHTTP adds the locale to the request context. func (m Middleware) ServeHTTP(w http.ResponseWriter, r *http.Request) { l := FromRequest(r) r = r.WithContext(WithLocale(r.Context(), l)) m.next.ServeHTTP(w, r) } -// Handler builds new Middleware +// Handler builds new Middleware. func Handler() func(http.Handler) http.Handler { return func(next http.Handler) http.Handler { return &Middleware{next: next} diff --git a/locale/http_test.go b/locale/http_test.go index ec0f527c1..a0fd53a00 100644 --- a/locale/http_test.go +++ b/locale/http_test.go @@ -12,7 +12,7 @@ import ( ) func TestEmptyRequest(t *testing.T) { - r, err := http.NewRequest("GET", "http://example.com/test", nil) + r, err := http.NewRequest(http.MethodGet, "http://example.com/test", nil) require.NoError(t, err) l := FromRequest(r) @@ -21,7 +21,7 @@ func TestEmptyRequest(t *testing.T) { } func TestFilledRequest(t *testing.T) { - r, err := http.NewRequest("GET", "http://example.com/test", nil) + r, err := http.NewRequest(http.MethodGet, "http://example.com/test", nil) require.NoError(t, err) r.Header.Set(HeaderAcceptLanguage, "fr-CH, fr;q=0.9, en;q=0.8, de;q=0.7, *;q=0.5") r.Header.Set(HeaderAcceptTimezone, "Europe/Paris") @@ -33,7 +33,7 @@ func TestFilledRequest(t *testing.T) { func TestExtendRequestWithEmptyLocale(t *testing.T) { l := new(Locale) - r, err := http.NewRequest("GET", "http://example.com/test", nil) + r, err := http.NewRequest(http.MethodGet, "http://example.com/test", nil) require.NoError(t, err) data, err := httputil.DumpRequest(l.Request(r), false) @@ -44,7 +44,7 @@ func TestExtendRequestWithEmptyLocale(t *testing.T) { func TestExtendRequest(t *testing.T) { l := NewLocale("fr-CH, fr;q=0.9, en;q=0.8, de;q=0.7, *;q=0.5", "Europe/Paris") - r, err := http.NewRequest("GET", "http://example.com/test", nil) + r, err := http.NewRequest(http.MethodGet, "http://example.com/test", nil) require.NoError(t, err) data, err := httputil.DumpRequest(l.Request(r), false) @@ -64,7 +64,7 @@ func (m *httpRecorderNext) ServeHTTP(w http.ResponseWriter, r *http.Request) { } func TestMiddlewareWithoutLocale(t *testing.T) { - r, err := http.NewRequest("GET", "http://example.com/test", nil) + r, err := http.NewRequest(http.MethodGet, "http://example.com/test", nil) require.NoError(t, err) rec := new(httpRecorderNext) @@ -79,7 +79,7 @@ func TestMiddlewareWithoutLocale(t *testing.T) { func TestMiddlewareWithLocale(t *testing.T) { l := NewLocale("fr-CH, fr;q=0.9, en;q=0.8, de;q=0.7, *;q=0.5", "Europe/Paris") - r, err := http.NewRequest("GET", "http://example.com/test", nil) + r, err := http.NewRequest(http.MethodGet, "http://example.com/test", nil) require.NoError(t, err) rec := new(httpRecorderNext) diff --git a/locale/locale.go b/locale/locale.go index 180700045..bacbe3c88 100644 --- a/locale/locale.go +++ b/locale/locale.go @@ -20,12 +20,12 @@ import ( "time" ) -// None is no timezone or language +// None is no timezone or language. const None = "" var ErrNoTimezone = errors.New("no timezone given") -// Locale contains the preferred language and timezone of the request +// Locale contains the preferred language and timezone of the request. type Locale struct { // Language as per RFC 7231, section 5.3.5: Accept-Language // Example: "fr-CH, fr;q=0.9, en;q=0.8, de;q=0.7, *;q=0.5" @@ -36,7 +36,7 @@ type Locale struct { } // NewLocale creates a new locale based on the passed accepted and language -// and timezone +// and timezone. func NewLocale(acceptLanguage, acceptTimezone string) *Locale { return &Locale{ acceptLanguage: acceptLanguage, @@ -44,27 +44,27 @@ func NewLocale(acceptLanguage, acceptTimezone string) *Locale { } } -// Language of the locale +// Language of the locale. func (l Locale) Language() string { return l.acceptLanguage } -// HasTimezone returns true if the language is defined, false otherwise +// HasLanguage returns true if the language is defined, false otherwise. func (l Locale) HasLanguage() bool { return l.acceptLanguage != None } -// Timezone of the locale +// Timezone of the locale. func (l Locale) Timezone() string { return l.acceptTimezone } -// HasTimezone returns true if the timezone is defined, false otherwise +// HasTimezone returns true if the timezone is defined, false otherwise. func (l Locale) HasTimezone() bool { return l.acceptTimezone != None } -// Location based of the locale timezone +// Location based of the locale timezone. func (l Locale) Location() (*time.Location, error) { if !l.HasTimezone() { return nil, ErrNoTimezone @@ -75,30 +75,32 @@ func (l Locale) Location() (*time.Location, error) { const serializeSep = "|" -// Serialize into a transportable form +// Serialize into a transportable form. func (l Locale) Serialize() string { return l.acceptLanguage + serializeSep + l.acceptTimezone } // Now returns the current time with the set timezone -// or local time if timezone is not set +// or local time if timezone is not set. func (l Locale) Now() time.Time { if l.HasTimezone() { loc, err := l.Location() if err != nil { // if the tz doesn't exist return time.Now() } + return time.Now().In(loc) } return time.Now() // Local } -// ParseLocale parses a serialized locale +// ParseLocale parses a serialized locale. func ParseLocale(serialized string) (*Locale, error) { parts := strings.Split(serialized, serializeSep) if len(parts) != 2 { return nil, fmt.Errorf("invalid locale format: %q", serialized) } + return NewLocale(parts[0], parts[1]), nil } diff --git a/locale/locale_test.go b/locale/locale_test.go index 735d96e4d..2503ff882 100644 --- a/locale/locale_test.go +++ b/locale/locale_test.go @@ -44,6 +44,7 @@ func TestTimezone(t *testing.T) { loc, err := l.Location() assert.NoError(t, err) + timeInUTC := time.Date(2018, 8, 30, 12, 0, 0, 0, time.UTC) assert.Equal(t, "2018-08-30 14:00:00 +0200 CEST", timeInUTC.In(loc).String()) } @@ -58,6 +59,7 @@ func TestTimezoneAndLocale(t *testing.T) { loc, err := l.Location() assert.NoError(t, err) + timeInUTC := time.Date(2018, 8, 30, 12, 0, 0, 0, time.UTC) assert.Equal(t, "2018-08-30 14:00:00 +0200 CEST", timeInUTC.In(loc).String()) } diff --git a/locale/strategy.go b/locale/strategy.go index 59b3cbc46..30145dcc0 100644 --- a/locale/strategy.go +++ b/locale/strategy.go @@ -7,19 +7,20 @@ import ( "context" ) -// Strategy defines a function that returns a Locale based on the passed Context +// Strategy defines a function that returns a Locale based on the passed Context. type Strategy func(ctx context.Context) *Locale -// NewContextStrategy returns a strategy that defines a static fallback language and timezone. +// NewFallbackStrategy returns a strategy that defines a static fallback language and timezone. // If only lang or timezone fallback should be defined as a fallback, the None value may be used. func NewFallbackStrategy(lang, timezone string) Strategy { l := NewLocale(lang, timezone) + return func(ctx context.Context) *Locale { return l } } -// NewContextStrategy returns a strategy that takes the locale form the request +// NewContextStrategy returns a strategy that takes the locale form the request. func NewContextStrategy() Strategy { return func(ctx context.Context) *Locale { l, _ := FromCtx(ctx) @@ -28,30 +29,36 @@ func NewContextStrategy() Strategy { } // StrategyList has a list of strategies that are evaluated to find -// the correct user locale +// the correct user locale. type StrategyList struct { strategies list.List } -// PushBack inserts the passed strategies at the back of list +// PushBack inserts the passed strategies at the back of list. func (s *StrategyList) PushBack(strategies ...Strategy) { for _, strategy := range strategies { s.strategies.PushBack(strategy) } } -// PushFront inserts a passed strategies at the front of list +// PushFront inserts a passed strategies at the front of list. func (s *StrategyList) PushFront(strategies ...Strategy) { for _, strategy := range strategies { s.strategies.PushFront(strategy) } } -// Locale executes all strategies and returns the new locale +// Locale executes all strategies and returns the new locale. func (s *StrategyList) Locale(ctx context.Context) *Locale { var l Locale + for i := s.strategies.Front(); i != nil; i = i.Next() { - curLoc := (i.Value.(Strategy))(ctx) + strategy, ok := i.Value.(Strategy) + if !ok { + break + } + + curLoc := strategy(ctx) // take language if defined if !l.HasLanguage() && curLoc.HasLanguage() { @@ -68,13 +75,16 @@ func (s *StrategyList) Locale(ctx context.Context) *Locale { break } } + return &l } -// NewDefaultFallbackStrategy returns a strategy list configured via environment +// NewDefaultFallbackStrategy returns a strategy list configured via environment. func NewDefaultFallbackStrategy() *StrategyList { var sl StrategyList + sl.PushFront(NewFallbackStrategy(cfg.Language, cfg.Timezone)) sl.PushFront(NewContextStrategy()) + return &sl } diff --git a/locale/strategy_test.go b/locale/strategy_test.go index 52df4f032..d0039c409 100644 --- a/locale/strategy_test.go +++ b/locale/strategy_test.go @@ -18,6 +18,7 @@ func TestStrategy(t *testing.T) { func TestStrategyWithCtx(t *testing.T) { var sl StrategyList + sl.PushBack( NewContextStrategy(), NewFallbackStrategy("de-DE", "Europe/Berlin"), diff --git a/maintenance/errors/bricks.go b/maintenance/errors/bricks.go index 601ec1913..8c609cfec 100644 --- a/maintenance/errors/bricks.go +++ b/maintenance/errors/bricks.go @@ -13,7 +13,7 @@ import ( // BricksError - a bricks err is a bricks specific error which provides // convenience functions to be transformed into runtime.Errors (JSON errors) // pb generate can be used to create a set of pre defined BricksErrors based -// on a JSON specification, see pb generate for details +// on a JSON specification, see pb generate for details. type BricksError struct { // title - a short, human-readable summary of the problem that SHOULD NOT change from occurrence // to occurrence of the problem, except for purposes of localization. @@ -32,6 +32,7 @@ func NewBricksError(opts ...BricksErrorOption) *BricksError { for _, opt := range opts { opt(e) } + return e } @@ -56,7 +57,7 @@ func (e *BricksError) Status() int { } // AsRuntimeError - returns the BricksError as bricks runtime.Error which aligns -// with a JSON error object +// with a JSON error object. func (e *BricksError) AsRuntimeError() *runtime.Error { j := &runtime.Error{ ID: uuid.NewString(), @@ -65,6 +66,7 @@ func (e *BricksError) AsRuntimeError() *runtime.Error { Title: e.title, Detail: e.detail, } + return j } diff --git a/maintenance/errors/context.go b/maintenance/errors/context.go index 0d49fc6f8..2f92dc5bf 100644 --- a/maintenance/errors/context.go +++ b/maintenance/errors/context.go @@ -21,7 +21,7 @@ func Hide(ctx context.Context, err, exposedErr error) error { ret := err if exposedErr != nil { - ret = fmt.Errorf("%w: %s", exposedErr, err) + ret = fmt.Errorf("%w: %s", exposedErr, err.Error()) } if ctx.Err() == context.Canceled && errors.Is(err, context.Canceled) { @@ -41,5 +41,6 @@ func IsStdLibContextError(err error) bool { if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) { return true } + return false } diff --git a/maintenance/errors/context_test.go b/maintenance/errors/context_test.go index b5681d60d..acedaf4bd 100644 --- a/maintenance/errors/context_test.go +++ b/maintenance/errors/context_test.go @@ -27,6 +27,7 @@ func TestHide(t *testing.T) { err error exposedErr error } + tests := []struct { name string args args @@ -49,26 +50,26 @@ func TestHide(t *testing.T) { err: iAmAnError, exposedErr: iAmAnotherError, }, - want: fmt.Errorf("%w: %s", iAmAnotherError, iAmAnError), + want: fmt.Errorf("%w: %s", iAmAnotherError, iAmAnError.Error()), }, { name: "normal_context_with_error_nothing_exposed", args: args{ ctx: backgroundContext, - err: fmt.Errorf("%s: %w", iAmAnError, context.Canceled), + err: fmt.Errorf("%s: %w", iAmAnError.Error(), context.Canceled), exposedErr: nil, }, - want: fmt.Errorf("%s: %w", iAmAnError, context.Canceled), + want: fmt.Errorf("%s: %w", iAmAnError.Error(), context.Canceled), expectContextErr: true, }, { name: "normal_context_with_error_with_exposed", args: args{ ctx: backgroundContext, - err: fmt.Errorf("%s: %w", iAmAnError, context.Canceled), + err: fmt.Errorf("%s: %w", iAmAnError.Error(), context.Canceled), exposedErr: iAmAnotherError, }, - want: fmt.Errorf("%w: %s: %s", iAmAnotherError, iAmAnError, context.Canceled), + want: fmt.Errorf("%w: %s: %s", iAmAnotherError, iAmAnError.Error(), context.Canceled.Error()), expectContextErr: true, }, { @@ -88,27 +89,27 @@ func TestHide(t *testing.T) { err: iAmAnError, exposedErr: iAmAnotherError, }, - want: fmt.Errorf("%w: %s", iAmAnotherError, iAmAnError), + want: fmt.Errorf("%w: %s", iAmAnotherError, iAmAnError.Error()), expectContextErr: false, }, { name: "canceled_context_with_error_nothing_exposed", args: args{ ctx: canceledContext, - err: fmt.Errorf("%s: %w", iAmAnError, context.Canceled), + err: fmt.Errorf("%s: %w", iAmAnError.Error(), context.Canceled), exposedErr: nil, }, - want: fmt.Errorf("%s: %w", iAmAnError, context.Canceled), + want: fmt.Errorf("%s: %w", iAmAnError.Error(), context.Canceled), expectContextErr: true, }, { name: "canceled_context_with_error_with_exposed", args: args{ ctx: canceledContext, - err: fmt.Errorf("%s: %w", iAmAnError, context.Canceled), + err: fmt.Errorf("%s: %w", iAmAnError.Error(), context.Canceled), exposedErr: iAmAnotherError, }, - want: fmt.Errorf("%s: %s: %w", iAmAnotherError, iAmAnError, context.Canceled), + want: fmt.Errorf("%s: %s: %w", iAmAnotherError.Error(), iAmAnError.Error(), context.Canceled), expectContextErr: true, }, { @@ -128,27 +129,27 @@ func TestHide(t *testing.T) { err: iAmAnError, exposedErr: iAmAnotherError, }, - want: fmt.Errorf("%w: %s", iAmAnotherError, iAmAnError), + want: fmt.Errorf("%w: %s", iAmAnotherError, iAmAnError.Error()), expectContextErr: false, }, { name: "exceeded_context_with_error_nothing_exposed", args: args{ ctx: exceededContext, - err: fmt.Errorf("%s: %w", iAmAnError, context.DeadlineExceeded), + err: fmt.Errorf("%s: %w", iAmAnError.Error(), context.DeadlineExceeded), exposedErr: nil, }, - want: fmt.Errorf("%s: %w", iAmAnError, context.DeadlineExceeded), + want: fmt.Errorf("%s: %w", iAmAnError.Error(), context.DeadlineExceeded), expectContextErr: true, }, { name: "exceeded_context_with_error_with_exposed", args: args{ ctx: exceededContext, - err: fmt.Errorf("%s: %w", iAmAnError, context.DeadlineExceeded), + err: fmt.Errorf("%s: %w", iAmAnError.Error(), context.DeadlineExceeded), exposedErr: iAmAnotherError, }, - want: fmt.Errorf("%s: %s: %w", iAmAnotherError, iAmAnError, context.DeadlineExceeded), + want: fmt.Errorf("%s: %s: %w", iAmAnotherError.Error(), iAmAnError.Error(), context.DeadlineExceeded), expectContextErr: true, }, } diff --git a/maintenance/errors/error.go b/maintenance/errors/error.go index 77e6265b7..d2744255c 100644 --- a/maintenance/errors/error.go +++ b/maintenance/errors/error.go @@ -11,12 +11,13 @@ import ( "os" "time" + "github.com/prometheus/client_golang/prometheus" + "github.com/rs/zerolog" + "github.com/pace/bricks/http/jsonapi/runtime" "github.com/pace/bricks/http/oauth2" "github.com/pace/bricks/maintenance/errors/raven" "github.com/pace/bricks/maintenance/log" - "github.com/prometheus/client_golang/prometheus" - "github.com/rs/zerolog" ) var paceHTTPPanicCounter = prometheus.NewGauge(prometheus.GaugeOpts{ @@ -28,7 +29,7 @@ func init() { prometheus.MustRegister(paceHTTPPanicCounter) } -// PanicWrap wraps a panic for HandleRequest +// PanicWrap wraps a panic for HandleRequest. type PanicWrap struct { err interface{} } @@ -52,8 +53,11 @@ func contextWithRequest(ctx context.Context, r *http.Request) context.Context { func requestFromContext(ctx context.Context) *http.Request { if v := ctx.Value(reqKey); v != nil { - return v.(*http.Request) + if out, ok := v.(*http.Request); ok { + return out + } } + return nil } @@ -63,6 +67,7 @@ func ContextTransfer(ctx, targetCtx context.Context) context.Context { if r := requestFromContext(ctx); r != nil { return contextWithRequest(targetCtx, r) } + return targetCtx } @@ -78,7 +83,7 @@ func (h *contextHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { h.next.ServeHTTP(w, r.WithContext(contextWithRequest(r.Context(), r))) } -// Handler implements a panic recovering middleware +// Handler implements a panic recovering middleware. func Handler() func(http.Handler) http.Handler { return func(next http.Handler) http.Handler { next = &contextHandler{next: next} @@ -86,7 +91,7 @@ func Handler() func(http.Handler) http.Handler { } } -// HandleRequest should be called with defer to recover panics in request handlers +// HandleRequest should be called with defer to recover panics in request handlers. func HandleRequest(handlerName string, w http.ResponseWriter, r *http.Request) { if rp := recover(); rp != nil { paceHTTPPanicCounter.Inc() @@ -94,21 +99,24 @@ func HandleRequest(handlerName string, w http.ResponseWriter, r *http.Request) { } } -// HandleError reports the passed error to sentry +// HandleError reports the passed error to sentry. func HandleError(rp interface{}, handlerName string, w http.ResponseWriter, r *http.Request) { ctx := r.Context() + pw, ok := rp.(*PanicWrap) if ok { log.Ctx(ctx).Error().Str("handler", handlerName).Msgf("Panic: %v", pw.err) + rp = pw.err // unwrap error } else { log.Ctx(ctx).Error().Str("handler", handlerName).Msgf("Error: %v", rp) } + log.Stack(ctx) sentryEvent{ctx, r, rp, 1, handlerName}.Send() - runtime.WriteError(w, http.StatusInternalServerError, errors.New("Internal Server Error")) + runtime.WriteError(w, http.StatusInternalServerError, errors.New("internal Server Error")) } // Handle logs the given error and reports it to sentry. @@ -116,16 +124,18 @@ func Handle(ctx context.Context, rp interface{}) { pw, ok := rp.(*PanicWrap) if ok { log.Ctx(ctx).Error().Msgf("Panic: %v", pw.err) + rp = pw.err // unwrap error } else { log.Ctx(ctx).Error().Msgf("Error: %v", rp) } + log.Stack(ctx) sentryEvent{ctx, nil, rp, 1, ""}.Send() } -// HandleWithCtx should be called with defer to recover panics in goroutines +// HandleWithCtx should be called with defer to recover panics in goroutines. func HandleWithCtx(ctx context.Context, handlerName string) { if rp := recover(); rp != nil { log.Ctx(ctx).Error().Str("handler", handlerName).Msgf("Panic: %v", rp) @@ -144,7 +154,7 @@ func New(text string) error { return errors.New(text) } -// WrapWithExtra adds extra data to an error before reporting to Sentry +// WrapWithExtra adds extra data to an error before reporting to Sentry. func WrapWithExtra(err error, extraInfo map[string]interface{}) error { return raven.WrapWithExtra(err, extraInfo) } @@ -171,6 +181,7 @@ func (e sentryEvent) build() *raven.Packet { } rvalStr := fmt.Sprint(rp) + var packet *raven.Packet if err, ok := rp.(error); ok { @@ -190,10 +201,12 @@ func (e sentryEvent) build() *raven.Packet { // add user userID, ok := oauth2.UserID(ctx) + user := raven.User{ID: userID} if r != nil { user.IP = log.ProxyAwareRemote(r) } + packet.Interfaces = append(packet.Interfaces, &user) if ok { packet.Tags = append(packet.Tags, raven.Tag{Key: "user_id", Value: userID}) @@ -204,14 +217,17 @@ func (e sentryEvent) build() *raven.Packet { packet.Extra["req_id"] = reqID packet.Tags = append(packet.Tags, raven.Tag{Key: "req_id", Value: reqID}) } + if traceID := log.TraceIDFromContext(ctx); traceID != "" { packet.Extra["uber_trace_id"] = traceID packet.Tags = append(packet.Tags, raven.Tag{Key: "trace_id", Value: traceID}) } + packet.Extra["handler"] = handlerName if clientID, ok := oauth2.ClientID(ctx); ok { packet.Extra["oauth2_client_id"] = clientID } + if scopes := oauth2.Scopes(ctx); len(scopes) > 0 { packet.Extra["oauth2_scopes"] = scopes } @@ -247,6 +263,7 @@ func getBreadcrumbs(ctx context.Context) []*raven.Breadcrumb { } result := make([]*raven.Breadcrumb, len(data)) + for i, d := range data { crumb, err := createBreadcrumb(d) if err != nil { @@ -268,6 +285,7 @@ func createBreadcrumb(data map[string]interface{}) (*raven.Breadcrumb, error) { if !ok { return nil, errors.New(`cannot parse "time"`) } + delete(data, "time") time, err := time.Parse(time.RFC3339, timeRaw) @@ -279,6 +297,7 @@ func createBreadcrumb(data map[string]interface{}) (*raven.Breadcrumb, error) { if !ok { return nil, errors.New(`cannot parse "level"`) } + delete(data, "level") level, err := translateZerologLevelToSentryLevel(levelRaw) @@ -290,12 +309,14 @@ func createBreadcrumb(data map[string]interface{}) (*raven.Breadcrumb, error) { if !ok { return nil, errors.New(`cannot parse "message"`) } + delete(data, "message") categoryRaw, ok := data["sentry:category"] if !ok { categoryRaw = "" } + delete(data, "sentry:category") category, ok := categoryRaw.(string) @@ -307,6 +328,7 @@ func createBreadcrumb(data map[string]interface{}) (*raven.Breadcrumb, error) { if !ok { typRaw = "" } + delete(data, "sentry:type") typ, ok := typRaw.(string) diff --git a/maintenance/errors/error_test.go b/maintenance/errors/error_test.go index dc56c2009..29a046aa3 100644 --- a/maintenance/errors/error_test.go +++ b/maintenance/errors/error_test.go @@ -30,13 +30,14 @@ func TestHandler(t *testing.T) { }) rec := httptest.NewRecorder() - req := httptest.NewRequest("GET", "/", nil) + req := httptest.NewRequest(http.MethodGet, "/", nil) mux.ServeHTTP(rec, req) if rec.Code != 500 { t.Errorf("Expected 500, got %d", rec.Code) } + if strings.Contains(rec.Body.String(), `"error":"Error"`) { t.Errorf(`Expected "error":"Error", got %q`, rec.Body.String()) } @@ -108,9 +109,9 @@ func Test_createBreadcrumb(t *testing.T) { "sentry:category": "http", "sentry:type": "http", "message": "HTTPS GET www.pace.car", - "method": "GET", + "method": http.MethodGet, "attempt": 1, - "status_code": 200, + "status_code": http.StatusOK, "duration": 227.717783, "url": "https://www.pace.car/", "req_id": "bpboj6bipt34r4teo7g0", @@ -122,9 +123,9 @@ func Test_createBreadcrumb(t *testing.T) { Timestamp: 1582795168, Type: "http", Data: map[string]interface{}{ - "method": "GET", + "method": http.MethodGet, "attempt": 1, - "status_code": 200, + "status_code": http.StatusOK, "duration": 227.717783, "url": "https://www.pace.car/", }, @@ -183,10 +184,10 @@ func Test_createBreadcrumb(t *testing.T) { // which should be passed to all subsequent requests and handler. func TestHandlerWithLogSink(t *testing.T) { rec1 := httptest.NewRecorder() - req1 := httptest.NewRequest("GET", "/test1", nil) + req1 := httptest.NewRequest(http.MethodGet, "/test1", nil) rec2 := httptest.NewRecorder() - req2 := httptest.NewRequest("GET", "/test2", nil) + req2 := httptest.NewRequest(http.MethodGet, "/test2", nil) var ( sink1Ctx context.Context @@ -196,29 +197,38 @@ func TestHandlerWithLogSink(t *testing.T) { mux := http.NewServeMux() mux.HandleFunc("/test1", func(w http.ResponseWriter, r *http.Request) { sink1Ctx = r.Context() + log.Ctx(r.Context()).Debug().Msg("ONLY FOR SINK1") w.WriteHeader(http.StatusOK) }) mux.HandleFunc("/test2", func(w http.ResponseWriter, r *http.Request) { require.NotEqual(t, "", log.RequestID(r), "request should have request id") + sink2Ctx = r.Context() client := &http.Client{ Transport: transport.Chain(&transport.LoggingRoundTripper{}, &transport.DumpRoundTripper{}), } - r0, err := http.NewRequest("GET", "https://www.pace.car/de", nil) + r0, err := http.NewRequest(http.MethodGet, "https://www.pace.car/de", nil) assert.NoError(t, err, `failed creating request to "/succeed"`) r0 = r0.WithContext(r.Context()) - _, err = client.Do(r0) + + resp, err := client.Do(r0) assert.NoError(t, err, `request to "/succeed" should not error`) - r1, err := http.NewRequest("GET", "http://localhost/fail", nil) + defer func() { + err := resp.Body.Close() + assert.NoError(t, err) + }() + + r1, err := http.NewRequest(http.MethodGet, "http://localhost/fail", nil) assert.NoError(t, err, `failed creating request to "/fail"`) r1 = r1.WithContext(r.Context()) - _, err = client.Do(r1) + + _, err = client.Do(r1) //nolint:bodyclose assert.Error(t, err, `request to "/fail" should error`) log.Req(r).Info(). @@ -228,22 +238,30 @@ func TestHandlerWithLogSink(t *testing.T) { panic("Sink2 Test Error, IGNORE") }) + handler := log.Handler()(Handler()(mux)) handler.ServeHTTP(rec1, req1) + resp1 := rec1.Result() require.Equal(t, http.StatusOK, resp1.StatusCode, "wrong status code") - resp1.Body.Close() + + err := resp1.Body.Close() + assert.NoError(t, err) handler.ServeHTTP(rec2, req2) + resp2 := rec2.Result() require.Equal(t, http.StatusInternalServerError, resp2.StatusCode, "wrong status code") - resp2.Body.Close() + + err = resp2.Body.Close() + assert.NoError(t, err) sink1, ok := log.SinkFromContext(sink1Ctx) assert.True(t, ok, "failed getting sink1") var sink1LogLines []json.RawMessage + assert.NoError(t, json.Unmarshal(sink1.ToJSON(), &sink1LogLines), "failed extracting logs from sink1") assert.Len(t, sink1LogLines, 2, "more log lines than expected") @@ -253,6 +271,7 @@ func TestHandlerWithLogSink(t *testing.T) { assert.True(t, ok, "failed getting sink2") var sink2LogLines []json.RawMessage + assert.NoError(t, json.Unmarshal(sink2.ToJSON(), &sink2LogLines), "failed extracting logs from sink2") assert.NotContains(t, string(sink2LogLines[0]), "ONLY FOR SINK1", "unexpected log line found") diff --git a/maintenance/errors/raven/client.go b/maintenance/errors/raven/client.go index e5f26fe24..544a28159 100644 --- a/maintenance/errors/raven/client.go +++ b/maintenance/errors/raven/client.go @@ -23,9 +23,10 @@ import ( "time" "github.com/certifi/gocertifi" + pkgErrors "github.com/pkg/errors" + "github.com/pace/bricks/maintenance/log" "github.com/pace/bricks/pkg/redact" - pkgErrors "github.com/pkg/errors" ) const ( @@ -65,6 +66,7 @@ func (timestamp *Timestamp) UnmarshalJSON(data []byte) error { } *timestamp = Timestamp(t) + return nil } @@ -111,7 +113,9 @@ func (t *Tag) UnmarshalJSON(data []byte) error { if err := json.Unmarshal(data, &tag); err != nil { return err } + *t = Tag{tag[0], tag[1]} + return nil } @@ -140,6 +144,7 @@ func (t *Tags) UnmarshalJSON(data []byte) error { } *t = tags + return nil } @@ -184,6 +189,7 @@ type Breadcrumb struct { func NewPacket(message string, interfaces ...Interface) *Packet { extra := Extra{} setExtraDefaults(extra) + return &Packet{ Message: message, Interfaces: interfaces, @@ -196,6 +202,7 @@ func NewPacketWithExtra(message string, extra Extra, interfaces ...Interface) *P if extra == nil { extra = Extra{} } + setExtraDefaults(extra) return &Packet{ @@ -210,6 +217,7 @@ func setExtraDefaults(extra Extra) Extra { extra["runtime.NumCPU"] = runtime.NumCPU() extra["runtime.GOMAXPROCS"] = runtime.GOMAXPROCS(0) // 0 just returns the current value extra["runtime.NumGoroutine"] = runtime.NumGoroutine() + return extra } @@ -219,25 +227,32 @@ func (packet *Packet) Init(project string) error { if packet.Project == "" { packet.Project = project } + if packet.EventID == "" { var err error + packet.EventID, err = uuid() if err != nil { return err } } + if time.Time(packet.Timestamp).IsZero() { packet.Timestamp = Timestamp(time.Now()) } + if packet.Level == "" { packet.Level = ERROR } + if packet.Logger == "" { packet.Logger = "root" } + if packet.ServerName == "" { packet.ServerName = hostname } + if packet.Platform == "" { packet.Platform = "go" } @@ -264,14 +279,17 @@ func (packet *Packet) AddTags(tags map[string]string) { func uuid() (string, error) { id := make([]byte, 16) + _, err := io.ReadFull(rand.Reader, id) if err != nil { return "", err } + id[6] &= 0x0F // clear version id[6] |= 0x40 // set version to 4 (random uuid) id[8] &= 0x3F // clear variant id[8] |= 0x80 // set to IETF variant + return hex.EncodeToString(id), nil } @@ -282,6 +300,7 @@ func (packet *Packet) JSON() ([]byte, error) { } interfaces := make(map[string]Interface, len(packet.Interfaces)) + for _, inter := range packet.Interfaces { if inter != nil { interfaces[inter.Class()] = inter @@ -293,6 +312,7 @@ func (packet *Packet) JSON() ([]byte, error) { if err != nil { return nil, err } + packetJSON[len(packetJSON)-1] = ',' packetJSON = append(packetJSON, interfaceJSON[1:]...) } @@ -312,6 +332,7 @@ func (c *context) setTags(t map[string]string) { if c.tags == nil { c.tags = make(map[string]string) } + for k, v := range t { c.tags[k] = v } @@ -323,23 +344,27 @@ func (c *context) clear() { c.tags = nil } -// Return a list of interfaces to be used in appending with the rest +// Return a list of interfaces to be used in appending with the rest. func (c *context) interfaces() []Interface { len, i := 0, 0 if c.user != nil { len++ } + if c.http != nil { len++ } + interfaces := make([]Interface, len) if c.user != nil { interfaces[i] = c.user i++ } + if c.http != nil { interfaces[i] = c.http } + return interfaces } @@ -349,6 +374,7 @@ var MaxQueueBuffer = 100 func newTransport() Transport { t := &HTTPTransport{} + rootCAs, err := gocertifi.CACerts() if err != nil { log.Println("raven: failed to load root TLS certificates:", err) @@ -360,6 +386,7 @@ func newTransport() Transport { }, } } + return t } @@ -372,16 +399,19 @@ func newClient(tags map[string]string) *Client { queue: make(chan *outgoingPacket, MaxQueueBuffer), } dsn := os.Getenv("SENTRY_DSN") + err := client.SetDSN(dsn) if err != nil && dsn != "" { log.Warnf("DSN environment was set to %q but failed: %v", dsn, err) } + client.SetRelease(os.Getenv("SENTRY_RELEASE")) client.SetEnvironment(os.Getenv("ENVIRONMENT")) + return client } -// New constructs a new Sentry client instance +// New constructs a new Sentry client instance. func New(dsn string) (*Client, error) { client := newClient(nil) return client, client.SetDSN(dsn) @@ -396,7 +426,7 @@ func NewWithTags(dsn string, tags map[string]string) (*Client, error) { // NewClient constructs a Sentry client and spawns a background goroutine to // handle packets sent by Client.Report. // -// Deprecated: use New and NewWithTags instead +// Deprecated: use New and NewWithTags instead. func NewClient(dsn string, tags map[string]string) (*Client, error) { client := newClient(tags) return client, client.SetDSN(dsn) @@ -440,11 +470,12 @@ type Client struct { start sync.Once } -// Initialize a default *Client instance +// Initialize a default *Client instance. var DefaultClient = newClient(nil) func (c *Client) SetIgnoreErrors(errs []string) error { joinedRegexp := strings.Join(errs, "|") + r, err := regexp.Compile(joinedRegexp) if err != nil { return fmt.Errorf("failed to compile regexp %q for %q: %v", joinedRegexp, errs, err) @@ -453,12 +484,14 @@ func (c *Client) SetIgnoreErrors(errs []string) error { c.mu.Lock() c.ignoreErrorsRegexp = r c.mu.Unlock() + return nil } func (c *Client) shouldExcludeErr(errStr string) bool { c.mu.RLock() defer c.mu.RUnlock() + return c.ignoreErrorsRegexp != nil && c.ignoreErrorsRegexp.MatchString(errStr) } @@ -484,6 +517,7 @@ func (client *Client) SetDSN(dsn string) error { if uri.User == nil { return ErrMissingUser } + publicKey := uri.User.Username() secretKey, hasSecretKey := uri.User.Password() uri.User = nil @@ -492,6 +526,7 @@ func (client *Client) SetDSN(dsn string) error { client.projectID = uri.Path[idx+1:] uri.Path = uri.Path[:idx+1] + "api/" + client.projectID + "/store/" } + if client.projectID == "" { return ErrMissingProjectID } @@ -507,13 +542,14 @@ func (client *Client) SetDSN(dsn string) error { return nil } -// Sets the DSN for the default *Client instance +// Sets the DSN for the default *Client instance. func SetDSN(dsn string) error { return DefaultClient.SetDSN(dsn) } // SetRelease sets the "release" tag. func (client *Client) SetRelease(release string) { client.mu.Lock() defer client.mu.Unlock() + client.release = release } @@ -521,6 +557,7 @@ func (client *Client) SetRelease(release string) { func (client *Client) SetEnvironment(environment string) { client.mu.Lock() defer client.mu.Unlock() + client.environment = environment } @@ -528,10 +565,11 @@ func (client *Client) SetEnvironment(environment string) { func (client *Client) SetDefaultLoggerName(name string) { client.mu.Lock() defer client.mu.Unlock() + client.defaultLoggerName = name } -// SetSampleRate sets how much sampling we want on client side +// SetSampleRate sets how much sampling we want on client side. func (client *Client) SetSampleRate(rate float32) error { client.mu.Lock() defer client.mu.Unlock() @@ -539,32 +577,34 @@ func (client *Client) SetSampleRate(rate float32) error { if rate < 0 || rate > 1 { return ErrInvalidSampleRate } + client.sampleRate = rate + return nil } -// SetRelease sets the "release" tag on the default *Client +// SetRelease sets the "release" tag on the default *Client. func SetRelease(release string) { DefaultClient.SetRelease(release) } -// SetEnvironment sets the "environment" tag on the default *Client +// SetEnvironment sets the "environment" tag on the default *Client. func SetEnvironment(environment string) { DefaultClient.SetEnvironment(environment) } -// SetDefaultLoggerName sets the "defaultLoggerName" on the default *Client +// SetDefaultLoggerName sets the "defaultLoggerName" on the default *Client. func SetDefaultLoggerName(name string) { DefaultClient.SetDefaultLoggerName(name) } -// SetSampleRate sets the "sample rate" on the degault *Client +// SetSampleRate sets the "sample rate" on the degault *Client. func SetSampleRate(rate float32) error { return DefaultClient.SetSampleRate(rate) } func (client *Client) worker() { for outgoingPacket := range client.queue { - client.mu.RLock() url, authHeader := client.url, client.authHeader client.mu.RUnlock() outgoingPacket.ch <- client.Transport.Send(url, authHeader, outgoingPacket.packet) + client.wg.Done() } } @@ -606,6 +646,7 @@ func (client *Client) Capture(packet *Packet, captureTags map[string]string) (ev // Initialize any required packet fields client.mu.RLock() packet.AddTags(client.context.tags) + projectID := client.projectID release := client.release environment := client.environment @@ -617,10 +658,11 @@ func (client *Client) Capture(packet *Packet, captureTags map[string]string) (ev packet.Logger = defaultLoggerName } - err := packet.Init(projectID) - if err != nil { + if err := packet.Init(projectID); err != nil { ch <- err + client.wg.Done() + return } @@ -648,6 +690,7 @@ func (client *Client) Capture(packet *Packet, captureTags map[string]string) (ev client.DropHandler(packet) } ch <- ErrPacketDropped + client.wg.Done() } @@ -677,7 +720,7 @@ func (client *Client) CaptureMessage(message string, tags map[string]string, int return eventID } -// CaptureMessage formats and delivers a string message to the Sentry server with the default *Client +// CaptureMessage formats and delivers a string message to the Sentry server with the default *Client. func CaptureMessage(message string, tags map[string]string, interfaces ...Interface) string { return DefaultClient.CaptureMessage(message, tags, interfaces...) } @@ -693,6 +736,7 @@ func (client *Client) CaptureMessageAndWait(message string, tags map[string]stri } packet := NewPacket(message, append(append(interfaces, client.context.interfaces()...), &Message{message, nil})...) + eventID, ch := client.Capture(packet, tags) if eventID != "" { <-ch @@ -736,7 +780,7 @@ func CaptureError(err error, tags map[string]string, interfaces ...Interface) st return DefaultClient.CaptureError(err, tags, interfaces...) } -// CaptureErrorAndWait is identical to CaptureError, except it blocks and assures that the event was sent +// CaptureErrorAndWait is identical to CaptureError, except it blocks and assures that the event was sent. func (client *Client) CaptureErrorAndWait(err error, tags map[string]string, interfaces ...Interface) string { if client == nil { return "" @@ -750,6 +794,7 @@ func (client *Client) CaptureErrorAndWait(err error, tags map[string]string, int cause := pkgErrors.Cause(err) packet := NewPacketWithExtra(err.Error(), extra, append(append(interfaces, client.context.interfaces()...), NewException(cause, GetOrNewStacktrace(cause, 1, 3, client.includePaths)))...) + eventID, ch := client.Capture(packet, tags) if eventID != "" { <-ch @@ -758,7 +803,7 @@ func (client *Client) CaptureErrorAndWait(err error, tags map[string]string, int return eventID } -// CaptureErrorAndWait is identical to CaptureError, except it blocks and assures that the event was sent +// CaptureErrorAndWait is identical to CaptureError, except it blocks and assures that the event was sent. func CaptureErrorAndWait(err error, tags map[string]string, interfaces ...Interface) string { return DefaultClient.CaptureErrorAndWait(err, tags, interfaces...) } @@ -772,6 +817,7 @@ func (client *Client) CapturePanic(f func(), tags map[string]string, interfaces // be completely noop though if we cared. defer func() { var packet *Packet + err = recover() switch rval := err.(type) { case nil: @@ -780,12 +826,14 @@ func (client *Client) CapturePanic(f func(), tags map[string]string, interfaces if client.shouldExcludeErr(rval.Error()) { return } + packet = NewPacket(rval.Error(), append(append(interfaces, client.context.interfaces()...), NewException(rval, NewStacktrace(2, 3, client.includePaths)))...) default: rvalStr := fmt.Sprint(rval) if client.shouldExcludeErr(rvalStr) { return } + packet = NewPacket(rvalStr, append(append(interfaces, client.context.interfaces()...), NewException(errors.New(rvalStr), NewStacktrace(2, 3, client.includePaths)))...) } @@ -793,6 +841,7 @@ func (client *Client) CapturePanic(f func(), tags map[string]string, interfaces }() f() + return } @@ -802,7 +851,7 @@ func CapturePanic(f func(), tags map[string]string, interfaces ...Interface) (in return DefaultClient.CapturePanic(f, tags, interfaces...) } -// CapturePanicAndWait is identical to CaptureError, except it blocks and assures that the event was sent +// CapturePanicAndWait is identical to CaptureError, except it blocks and assures that the event was sent. func (client *Client) CapturePanicAndWait(f func(), tags map[string]string, interfaces ...Interface) (err interface{}, errorID string) { // Note: This doesn't need to check for client, because we still want to go through the defer/recover path // Down the line, Capture will be noop'd, so while this does a _tiny_ bit of overhead constructing the @@ -810,6 +859,7 @@ func (client *Client) CapturePanicAndWait(f func(), tags map[string]string, inte // be completely noop though if we cared. defer func() { var packet *Packet + err = recover() switch rval := err.(type) { case nil: @@ -818,16 +868,19 @@ func (client *Client) CapturePanicAndWait(f func(), tags map[string]string, inte if client.shouldExcludeErr(rval.Error()) { return } + packet = NewPacket(rval.Error(), append(append(interfaces, client.context.interfaces()...), NewException(rval, NewStacktrace(2, 3, client.includePaths)))...) default: rvalStr := fmt.Sprint(rval) if client.shouldExcludeErr(rvalStr) { return } + packet = NewPacket(rvalStr, append(append(interfaces, client.context.interfaces()...), NewException(errors.New(rvalStr), NewStacktrace(2, 3, client.includePaths)))...) } var ch chan error + errorID, ch = client.Capture(packet, tags) if errorID != "" { <-ch @@ -835,10 +888,11 @@ func (client *Client) CapturePanicAndWait(f func(), tags map[string]string, inte }() f() + return } -// CapturePanicAndWait is identical to CaptureError, except it blocks and assures that the event was sent +// CapturePanicAndWait is identical to CaptureError, except it blocks and assures that the event was sent. func CapturePanicAndWait(f func(), tags map[string]string, interfaces ...Interface) (interface{}, string) { return DefaultClient.CapturePanicAndWait(f, tags, interfaces...) } @@ -849,12 +903,12 @@ func (client *Client) Close() { func Close() { DefaultClient.Close() } -// Wait blocks and waits for all events to finish being sent to Sentry server +// Wait blocks and waits for all events to finish being sent to Sentry server. func (client *Client) Wait() { client.wg.Wait() } -// Wait blocks and waits for all events to finish being sent to Sentry server +// Wait blocks and waits for all events to finish being sent to Sentry server. func Wait() { DefaultClient.Wait() } func (client *Client) URL() string { @@ -946,22 +1000,28 @@ func (t *HTTPTransport) Send(url, authHeader string, packet *Packet) error { if err != nil { return fmt.Errorf("error serializing packet: %v", err) } - req, err := http.NewRequest("POST", url, body) + + req, err := http.NewRequest(http.MethodPost, url, body) if err != nil { return fmt.Errorf("can't create new request: %v", err) } + req.Header.Set("X-Sentry-Auth", authHeader) req.Header.Set("User-Agent", userAgent) req.Header.Set("Content-Type", contentType) + res, err := t.Do(req) if err != nil { return err } - io.Copy(io.Discard, res.Body) // nolint: errcheck + + io.Copy(io.Discard, res.Body) res.Body.Close() - if res.StatusCode != 200 { + + if res.StatusCode != http.StatusOK { return fmt.Errorf("raven: got http status %d", res.StatusCode) } + return nil } @@ -980,11 +1040,13 @@ func serializedPacket(packet *Packet) (io.Reader, string, error) { buf := &bytes.Buffer{} b64 := base64.NewEncoder(base64.StdEncoding, buf) deflate, _ := zlib.NewWriterLevel(b64, zlib.BestCompression) - deflate.Write(packetJSON) // nolint: errcheck + deflate.Write(packetJSON) deflate.Close() b64.Close() + return buf, "application/octet-stream", nil } + return bytes.NewReader(packetJSON), "application/json", nil } diff --git a/maintenance/errors/raven/errors.go b/maintenance/errors/raven/errors.go index 5e5727043..8c5c02e4d 100644 --- a/maintenance/errors/raven/errors.go +++ b/maintenance/errors/raven/errors.go @@ -21,7 +21,7 @@ func (ewx *errWrappedWithExtra) ExtraInfo() Extra { return ewx.extraInfo } -// Adds extra data to an error before reporting to Sentry +// Adds extra data to an error before reporting to Sentry. func WrapWithExtra(err error, extraInfo map[string]interface{}) error { return &errWrappedWithExtra{ err: err, diff --git a/maintenance/errors/raven/exception.go b/maintenance/errors/raven/exception.go index 552eaad12..f47e96684 100644 --- a/maintenance/errors/raven/exception.go +++ b/maintenance/errors/raven/exception.go @@ -14,9 +14,11 @@ func NewException(err error, stacktrace *Stacktrace) *Exception { Value: msg, Type: reflect.TypeOf(err).String(), } + if m := errorMsgPattern.FindStringSubmatch(msg); m != nil { ex.Module, ex.Value = m[1], m[2] } + return ex } @@ -37,6 +39,7 @@ func (e *Exception) Culprit() string { if e.Stacktrace == nil { return "" } + return e.Stacktrace.Culprit() } diff --git a/maintenance/errors/raven/http.go b/maintenance/errors/raven/http.go index 0d8fbb112..3d2b76262 100644 --- a/maintenance/errors/raven/http.go +++ b/maintenance/errors/raven/http.go @@ -15,6 +15,7 @@ func NewHttp(req *http.Request) *Http { if req.TLS != nil || req.Header.Get("X-Forwarded-Proto") == "https" { proto = "https" } + h := &Http{ Method: req.Method, Cookies: req.Header.Get("Cookie"), @@ -25,10 +26,13 @@ func NewHttp(req *http.Request) *Http { if addr, port, err := net.SplitHostPort(req.RemoteAddr); err == nil { h.Env = map[string]string{"REMOTE_ADDR": addr, "REMOTE_PORT": port} } + for k, v := range req.Header { h.Headers[k] = strings.Join(v, ",") } + h.Headers["Host"] = req.Host + return h } @@ -42,6 +46,7 @@ func sanitizeQuery(query url.Values) url.Values { } } } + return query } @@ -74,13 +79,17 @@ func RecoveryHandler(handler func(http.ResponseWriter, *http.Request)) func(http defer func() { if rval := recover(); rval != nil { debug.PrintStack() + rvalStr := fmt.Sprint(rval) + var packet *Packet + if err, ok := rval.(error); ok { packet = NewPacket(rvalStr, NewException(errors.New(rvalStr), GetOrNewStacktrace(err, 2, 3, nil)), NewHttp(r)) } else { packet = NewPacket(rvalStr, NewException(errors.New(rvalStr), NewStacktrace(2, 3, nil)), NewHttp(r)) } + Capture(packet, nil) w.WriteHeader(http.StatusInternalServerError) } diff --git a/maintenance/errors/raven/stacktrace.go b/maintenance/errors/raven/stacktrace.go index aab2a3429..3db53c1fd 100644 --- a/maintenance/errors/raven/stacktrace.go +++ b/maintenance/errors/raven/stacktrace.go @@ -32,6 +32,7 @@ func (s *Stacktrace) Culprit() string { return frame.Module + "." + frame.Function } } + return "" } @@ -51,28 +52,33 @@ type StacktraceFrame struct { InApp bool `json:"in_app"` } -// Try to get stacktrace from err as an interface of github.com/pkg/errors, or else NewStacktrace() +// Try to get stacktrace from err as an interface of github.com/pkg/errors, or else NewStacktrace(). func GetOrNewStacktrace(err error, skip int, context int, appPackagePrefixes []string) *Stacktrace { stacktracer, errHasStacktrace := err.(interface { StackTrace() errors.StackTrace }) if errHasStacktrace { var frames []*StacktraceFrame + for _, f := range stacktracer.StackTrace() { pc := uintptr(f) - 1 fn := runtime.FuncForPC(pc) + var file string + var line int if fn != nil { file, line = fn.FileLine(pc) } else { file = "unknown" } + frame := NewStacktraceFrame(pc, file, line, context, appPackagePrefixes) if frame != nil { frames = append([]*StacktraceFrame{frame}, frames...) } } + return &Stacktrace{Frames: frames} } else { return NewStacktrace(skip+1, context, appPackagePrefixes) @@ -89,11 +95,13 @@ func GetOrNewStacktrace(err error, skip int, context int, appPackagePrefixes []s // be considered "in app". func NewStacktrace(skip int, context int, appPackagePrefixes []string) *Stacktrace { var frames []*StacktraceFrame + for i := 1 + skip; ; i++ { pc, file, line, ok := runtime.Caller(i) if !ok { break } + frame := NewStacktraceFrame(pc, file, line, context, appPackagePrefixes) if frame != nil { frames = append(frames, frame) @@ -111,6 +119,7 @@ func NewStacktrace(skip int, context int, appPackagePrefixes []string) *Stacktra for i, j := 0, len(frames)-1; i < j; i, j = i+1, j-1 { frames[i], frames[j] = frames[j], frames[i] } + return &Stacktrace{frames} } @@ -162,6 +171,7 @@ func NewStacktraceFrame(pc uintptr, file string, line, context int, appPackagePr frame.ContextLine = string(contextLine[0]) } } + return frame } @@ -199,6 +209,7 @@ var ( func fileContext(filename string, line, context int) ([][]byte, int) { fileCacheLock.Lock() defer fileCacheLock.Unlock() + lines, ok := fileCache[filename] if !ok { data, err := os.ReadFile(filename) @@ -209,6 +220,7 @@ func fileContext(filename string, line, context int) ([][]byte, int) { fileCache[filename] = nil return nil, 0 } + lines = bytes.Split(data, []byte{'\n'}) fileCache[filename] = lines } @@ -220,32 +232,39 @@ func fileContext(filename string, line, context int) ([][]byte, int) { line-- // stack trace lines are 1-indexed start := line - context + var idx int + if start < 0 { start = 0 idx = line } else { idx = context } + end := line + context + 1 + if line >= len(lines) { return nil, 0 } + if end > len(lines) { end = len(lines) } + return lines[start:end], idx } var trimPaths []string -// Try to trim the GOROOT or GOPATH prefix off of a filename +// Try to trim the GOROOT or GOPATH prefix off of a filename. func trimPath(filename string) string { for _, prefix := range trimPaths { if trimmed := strings.TrimPrefix(filename, prefix); len(trimmed) < len(filename) { return trimmed } } + return filename } @@ -256,6 +275,7 @@ func init() { if prefix[len(prefix)-1] != filepath.Separator { prefix += string(filepath.Separator) } + trimPaths = append(trimPaths, prefix) } } diff --git a/maintenance/failover/failover.go b/maintenance/failover/failover.go index 82df646cb..cef699040 100644 --- a/maintenance/failover/failover.go +++ b/maintenance/failover/failover.go @@ -11,11 +11,12 @@ import ( "time" "github.com/bsm/redislock" + "github.com/redis/go-redis/v9" + "github.com/pace/bricks/backend/k8sapi" "github.com/pace/bricks/maintenance/errors" "github.com/pace/bricks/maintenance/health" "github.com/pace/bricks/maintenance/log" - "github.com/redis/go-redis/v9" ) type status int @@ -158,11 +159,15 @@ func (a *ActivePassive) Stop() { a.close <- struct{}{} } -// Handler implements the readiness http endpoint +// Handler implements the readiness http endpoint. func (a *ActivePassive) Handler(w http.ResponseWriter, r *http.Request) { label := a.label(a.getState()) + w.WriteHeader(http.StatusOK) - fmt.Fprintln(w, strings.ToUpper(label)) + + if _, err := fmt.Fprintln(w, strings.ToUpper(label)); err != nil { + w.WriteHeader(http.StatusInternalServerError) + } } func (a *ActivePassive) label(s status) string { @@ -196,19 +201,21 @@ func (a *ActivePassive) becomeUndefined(ctx context.Context) { a.setState(ctx, UNDEFINED) } -// setState returns true if the state was set successfully +// setState returns true if the state was set successfully. func (a *ActivePassive) setState(ctx context.Context, state status) bool { - err := a.k8sClient.SetCurrentPodLabel(ctx, Label, a.label(state)) - if err != nil { + if err := a.k8sClient.SetCurrentPodLabel(ctx, Label, a.label(state)); err != nil { log.Ctx(ctx).Error().Err(err).Msg("failed to mark pod as undefined") a.stateMu.Lock() a.state = UNDEFINED a.stateMu.Unlock() + return false } + a.stateMu.Lock() a.state = state a.stateMu.Unlock() + return true } @@ -216,5 +223,6 @@ func (a *ActivePassive) getState() status { a.stateMu.RLock() state := a.state a.stateMu.RUnlock() + return state } diff --git a/maintenance/health/health.go b/maintenance/health/health.go index 495dee15c..4ab0045fe 100644 --- a/maintenance/health/health.go +++ b/maintenance/health/health.go @@ -30,8 +30,8 @@ func (h *handler) ServeHTTP(w http.ResponseWriter, r *http.Request) { )) } -// ReadinessCheck allows to set a different function for the readiness check. The default readiness check -// is the same as the liveness check and does always return OK +// SetCustomReadinessCheck allows to set a different function for the readiness check. The default readiness check +// is the same as the liveness check and does always return OK. func SetCustomReadinessCheck(check func(http.ResponseWriter, *http.Request)) { readinessCheck.check = check } @@ -39,18 +39,19 @@ func SetCustomReadinessCheck(check func(http.ResponseWriter, *http.Request)) { func liveness(w http.ResponseWriter, r *http.Request) { w.Header().Set("Content-Type", "text/plain") w.WriteHeader(http.StatusOK) + if _, err := fmt.Fprint(w, "OK\n"); err != nil { log.Warnf("could not write output: %s", err) } } -// HandlerLiveness returns the liveness handler that always return OK and 200 +// HandlerLiveness returns the liveness handler that always return OK and 200. func HandlerLiveness() http.Handler { return &handler{check: liveness} } // HandlerReadiness returns the readiness handler. This handler can be configured with -// ReadinessCheck(func(http.ResponseWriter,*http.Request)), the default behavior is a liveness check +// ReadinessCheck(func(http.ResponseWriter,*http.Request)), the default behavior is a liveness check. func HandlerReadiness() http.Handler { return readinessCheck } diff --git a/maintenance/health/health_test.go b/maintenance/health/health_test.go index 8947bd1b3..4aed4b5fe 100644 --- a/maintenance/health/health_test.go +++ b/maintenance/health/health_test.go @@ -8,31 +8,35 @@ import ( "net/http/httptest" "testing" - "github.com/pace/bricks/maintenance/log" "github.com/stretchr/testify/require" + + "github.com/pace/bricks/maintenance/log" ) func TestHandlerLiveness(t *testing.T) { rec := httptest.NewRecorder() - req := httptest.NewRequest("GET", "/health/liveness", nil) + req := httptest.NewRequest(http.MethodGet, "/health/liveness", nil) HandlerLiveness().ServeHTTP(rec, req) - checkResult(rec, 200, "OK\n", t) + checkResult(rec, http.StatusOK, "OK\n", t) } func TestHandlerReadiness(t *testing.T) { // check the default rec := httptest.NewRecorder() - req := httptest.NewRequest("GET", "/health/readiness", nil) + req := httptest.NewRequest(http.MethodGet, "/health/readiness", nil) HandlerReadiness().ServeHTTP(rec, req) // check another readiness check - checkResult(rec, 200, "OK\n", t) + checkResult(rec, http.StatusOK, "OK\n", t) + rec = httptest.NewRecorder() + SetCustomReadinessCheck(func(w http.ResponseWriter, request *http.Request) { w.Header().Set("Content-Type", "text/plain") w.WriteHeader(http.StatusNotFound) + if _, err := w.Write([]byte("Err\n")); err != nil { log.Warnf("could not write output: %s", err) } @@ -43,10 +47,16 @@ func TestHandlerReadiness(t *testing.T) { func checkResult(rec *httptest.ResponseRecorder, expCode int, expBody string, t *testing.T) { resp := rec.Result() - defer resp.Body.Close() + + defer func() { + err := resp.Body.Close() + require.NoError(t, err) + }() + if resp.StatusCode != expCode { t.Errorf("Expected /health to respond with %d, got: %d", expCode, resp.StatusCode) } + data, err := io.ReadAll(resp.Body) require.NoError(t, err) require.Equal(t, expBody, string(data)) diff --git a/maintenance/health/servicehealthcheck/config.go b/maintenance/health/servicehealthcheck/config.go index 65e9a44ff..85efda966 100644 --- a/maintenance/health/servicehealthcheck/config.go +++ b/maintenance/health/servicehealthcheck/config.go @@ -24,8 +24,7 @@ type config struct { var cfg config func init() { - err := env.Parse(&cfg) - if err != nil { + if err := env.Parse(&cfg); err != nil { log.Fatalf("Failed to parse health check environment: %v", err) } } @@ -58,7 +57,7 @@ func UseMaxWait(maxWait time.Duration) HealthCheckOption { } } -// UseWarmup - delays a healthcheck during warmup +// UseWarmup - delays a healthcheck during warmup. func UseWarmup(delay time.Duration) HealthCheckOption { return func(cfg *HealthCheckCfg) { cfg.warmupDelay = delay diff --git a/maintenance/health/servicehealthcheck/connection_state.go b/maintenance/health/servicehealthcheck/connection_state.go index f9b485587..504029664 100644 --- a/maintenance/health/servicehealthcheck/connection_state.go +++ b/maintenance/health/servicehealthcheck/connection_state.go @@ -17,6 +17,7 @@ type ConnectionState struct { func (cs *ConnectionState) setConnectionState(result HealthCheckResult) { cs.m.Lock() defer cs.m.Unlock() + cs.result = result cs.lastCheck = time.Now() } @@ -36,6 +37,7 @@ func (cs *ConnectionState) SetHealthy() { func (cs *ConnectionState) GetState() HealthCheckResult { cs.m.Lock() defer cs.m.Unlock() + return cs.result } @@ -43,5 +45,6 @@ func (cs *ConnectionState) GetState() HealthCheckResult { func (cs *ConnectionState) LastChecked() time.Time { cs.m.Lock() defer cs.m.Unlock() + return cs.lastCheck } diff --git a/maintenance/health/servicehealthcheck/health_handler.go b/maintenance/health/servicehealthcheck/health_handler.go index 45e84ae01..2f19b886a 100644 --- a/maintenance/health/servicehealthcheck/health_handler.go +++ b/maintenance/health/servicehealthcheck/health_handler.go @@ -14,7 +14,9 @@ import ( func HealthHandler() http.HandlerFunc { return func(w http.ResponseWriter, _ *http.Request) { var errors []string + var warnings []string + for name, res := range checksResults(&requiredChecks) { if res.State == Err { errors = append(errors, fmt.Sprintf("%s: %s", name, res.Msg)) @@ -22,12 +24,16 @@ func HealthHandler() http.HandlerFunc { warnings = append(warnings, fmt.Sprintf("%s: %s", name, res.Msg)) } } + if len(errors) > 0 { log.Logger().Info().Strs("errors", errors).Strs("warnings", warnings).Msg("Health check failed") + msg := fmt.Sprintf("ERR: %d errors and %d warnings", len(errors), len(warnings)) writeResult(w, http.StatusServiceUnavailable, msg) + return } + writeResult(w, http.StatusOK, string(Ok)) } } @@ -35,6 +41,7 @@ func HealthHandler() http.HandlerFunc { func writeResult(w http.ResponseWriter, status int, body string) { w.Header().Set("Content-Type", "text/plain") w.WriteHeader(status) + if _, err := fmt.Fprint(w, body); err != nil { log.Warnf("could not write output: %s", err) } diff --git a/maintenance/health/servicehealthcheck/health_handler_json.go b/maintenance/health/servicehealthcheck/health_handler_json.go index cb54c5029..7acf4ff76 100644 --- a/maintenance/health/servicehealthcheck/health_handler_json.go +++ b/maintenance/health/servicehealthcheck/health_handler_json.go @@ -22,8 +22,11 @@ func JSONHealthHandler() http.HandlerFunc { checkResponse := make(map[string]serviceStats) var errors []string + var warnings []string + status := http.StatusOK + for name, res := range checksResults(&requiredChecks) { scr := serviceStats{ Status: res.State, @@ -33,10 +36,12 @@ func JSONHealthHandler() http.HandlerFunc { if res.State == Err { scr.Error = res.Msg status = http.StatusServiceUnavailable + errors = append(errors, fmt.Sprintf("%s: %s", name, res.Msg)) } else if res.State == Warn { warnings = append(warnings, fmt.Sprintf("%s: %s", name, res.Msg)) } + checkResponse[name] = scr } @@ -50,6 +55,7 @@ func JSONHealthHandler() http.HandlerFunc { scr.Error = res.Msg status = http.StatusServiceUnavailable } + checkResponse[name] = scr } @@ -59,8 +65,8 @@ func JSONHealthHandler() http.HandlerFunc { w.Header().Set("Content-Type", "application/json") w.WriteHeader(status) - err := json.NewEncoder(w).Encode(checkResponse) - if err != nil { + + if err := json.NewEncoder(w).Encode(checkResponse); err != nil { log.Warnf("json health handler endpoint: encoding failed: %v", err) } } diff --git a/maintenance/health/servicehealthcheck/health_handler_json_test.go b/maintenance/health/servicehealthcheck/health_handler_json_test.go index fdfad8ca2..78ff3855b 100644 --- a/maintenance/health/servicehealthcheck/health_handler_json_test.go +++ b/maintenance/health/servicehealthcheck/health_handler_json_test.go @@ -76,12 +76,14 @@ func TestJSONHealthHandler(t *testing.T) { for _, hc := range tc.requiredHC { RegisterHealthCheck(hc.name, hc) } + for _, hc := range tc.optionalHC { RegisterOptionalHealthCheck(hc, hc.name) } testRequest(t, handler, tc.expCode, func(t *testing.T, resBody []byte) { var res map[string]serviceStats + err := json.Unmarshal(resBody, &res) require.NoError(t, err) diff --git a/maintenance/health/servicehealthcheck/health_handler_readable.go b/maintenance/health/servicehealthcheck/health_handler_readable.go index 822bbbf8d..de82c2c39 100644 --- a/maintenance/health/servicehealthcheck/health_handler_readable.go +++ b/maintenance/health/servicehealthcheck/health_handler_readable.go @@ -9,7 +9,7 @@ import ( "strings" ) -// saves length of the longest name for the column width in the table. 20 characters width is the default +// saves length of the longest name for the column width in the table. 20 characters width is the default. var longestCheckName = 20 // ReadableHealthHandler returns the health endpoint with all details about service health. This handler checks @@ -24,13 +24,17 @@ func ReadableHealthHandler() http.HandlerFunc { table := "%-" + strconv.Itoa(longestCheckName) + "s %-3s %s\n" bodyBuilder := &strings.Builder{} bodyBuilder.WriteString("Required Services: \n") + for name, res := range reqChecks { bodyBuilder.WriteString(fmt.Sprintf(table, name, res.State, res.Msg)) + if res.State == Err { status = http.StatusServiceUnavailable } } + bodyBuilder.WriteString("Optional Services: \n") + for name, res := range optChecks { bodyBuilder.WriteString(fmt.Sprintf(table, name, res.State, res.Msg)) // do not change status, as this is optional diff --git a/maintenance/health/servicehealthcheck/health_handler_readable_test.go b/maintenance/health/servicehealthcheck/health_handler_readable_test.go index 210f7c5f4..1e11af72b 100644 --- a/maintenance/health/servicehealthcheck/health_handler_readable_test.go +++ b/maintenance/health/servicehealthcheck/health_handler_readable_test.go @@ -65,6 +65,7 @@ func TestReadableHealthHandler(t *testing.T) { for _, hc := range tc.req { RegisterHealthCheck(hc.name, hc) } + for _, hc := range tc.opt { RegisterOptionalHealthCheck(hc, hc.name) } @@ -75,6 +76,7 @@ func TestReadableHealthHandler(t *testing.T) { results := strings.Split(string(resBody), "Optional Services: \n") reqRes := strings.Split(strings.Split(results[0], "Required Services: \n")[1], "\n") optRes := strings.Split(results[1], "\n") + testListHealthChecks(t, tc.expReq, reqRes) testListHealthChecks(t, tc.expOpt, optRes) }) diff --git a/maintenance/health/servicehealthcheck/healthcheck.go b/maintenance/health/servicehealthcheck/healthcheck.go index f4f3ae82d..f3d840231 100755 --- a/maintenance/health/servicehealthcheck/healthcheck.go +++ b/maintenance/health/servicehealthcheck/healthcheck.go @@ -26,20 +26,20 @@ func (hcf HealthCheckFunc) HealthCheck(ctx context.Context) HealthCheckResult { return hcf(ctx) } -// Initializable is used to mark that a health check needs to be initialized +// Initializable is used to mark that a health check needs to be initialized. type Initializable interface { Init(ctx context.Context) error } -// HealthState describes if a any error or warning occurred during the health check of a service +// HealthState describes if a any error or warning occurred during the health check of a service. type HealthState string const ( - // Err State of a service, if an error occurred during the health check of the service + // Err State of a service, if an error occurred during the health check of the service. Err HealthState = "ERR" - // Warn State of a service, if a warning occurred during the health check of the service + // Warn State of a service, if a warning occurred during the health check of the service. Warn HealthState = "WARN" - // Ok State of a service, if no warning or error occurred during the health check of the service + // Ok State of a service, if no warning or error occurred during the health check of the service. Ok HealthState = "OK" ) @@ -51,40 +51,52 @@ type HealthCheckResult struct { Msg string } -// requiredChecks contains all required registered Health Checks - key:Name +// requiredChecks contains all required registered Health Checks - key:Name. var requiredChecks sync.Map -// optionalChecks contains all optional registered Health Checks - key:Name +// optionalChecks contains all optional registered Health Checks - key:Name. var optionalChecks sync.Map func checksResults(checks *sync.Map) map[string]HealthCheckResult { results := make(map[string]HealthCheckResult) + checks.Range(func(key, value interface{}) bool { - name := key.(string) - result := value.(*ConnectionState).GetState() + name, _ := key.(string) + if name == "" { + return true + } + + state, ok := value.(*ConnectionState) + if !ok { + return true + } + + result := state.GetState() results[name] = result + return true }) + return results } // RegisterHealthCheck registers a required HealthCheck. The name // must be unique. If the health check satisfies the Initializable interface, it // is initialized before it is added. -// It is not possible to add a health check with the same name twice, even if one is required and one is optional +// It is not possible to add a health check with the same name twice, even if one is required and one is optional. func RegisterHealthCheck(name string, hc HealthCheck, opts ...HealthCheckOption) { registerHealthCheck(&requiredChecks, name, hc, opts...) } // RegisterHealthCheckFunc registers a required HealthCheck. The name // must be unique. It is not possible to add a health check with the same name twice, -// even if one is required and one is optional +// even if one is required and one is optional. func RegisterHealthCheckFunc(name string, f HealthCheckFunc, opts ...HealthCheckOption) { RegisterHealthCheck(name, f, opts...) } // RegisterOptionalHealthCheck registers a HealthCheck like RegisterHealthCheck(hc HealthCheck, name string) -// but the health check is only checked for /health/check and not for /health/ +// but the health check is only checked for /health/check and not for /health/. func RegisterOptionalHealthCheck(hc HealthCheck, name string, opts ...HealthCheckOption) { registerHealthCheck(&optionalChecks, name, hc, opts...) } @@ -110,6 +122,7 @@ func registerHealthCheck(checks *sync.Map, name string, check HealthCheck, opts log.Warnf("tried to register health check with name %q twice", name) return } + if _, inOpt := optionalChecks.Load(name); inOpt { log.Warnf("tried to register health check with name %q twice", name) return @@ -119,7 +132,9 @@ func registerHealthCheck(checks *sync.Map, name string, check HealthCheck, opts if len(name) > longestCheckName { longestCheckName = len(name) } + var bgState ConnectionState + checks.Store(name, &bgState) go func() { @@ -153,6 +168,7 @@ func registerHealthCheck(checks *sync.Map, name string, check HealthCheck, opts // calculate when the warmup phase should be finished healthCheckStart := time.Now() warmupDeadline := healthCheckStart.Add(hcCfg.warmupDelay) + for { <-timer.C func() { @@ -174,6 +190,7 @@ func registerHealthCheck(checks *sync.Map, name string, check HealthCheck, opts // Too soon, leave the same state return } + initErr := initHealthCheck(ctx, initHC) if initErr != nil { // Init failed again @@ -196,6 +213,7 @@ func registerHealthCheck(checks *sync.Map, name string, check HealthCheck, opts }) // sanity trigger a health check, since we can not guarantee what the real implementation does ... go check.HealthCheck(ctx) + return } } @@ -207,11 +225,12 @@ func registerHealthCheck(checks *sync.Map, name string, check HealthCheck, opts }() } -// initHealthCheck will recover from panics and return a proper error +// initHealthCheck will recover from panics and return a proper error. func initHealthCheck(ctx context.Context, initHC Initializable) (err error) { defer func() { if rp := recover(); rp != nil { err = fmt.Errorf("panic: %v", rp) + errors.Handle(ctx, rp) } }() diff --git a/maintenance/health/servicehealthcheck/healthcheck_test.go b/maintenance/health/servicehealthcheck/healthcheck_test.go index 197722d54..3e323cb4c 100644 --- a/maintenance/health/servicehealthcheck/healthcheck_test.go +++ b/maintenance/health/servicehealthcheck/healthcheck_test.go @@ -51,7 +51,7 @@ func TestHandlerHealthCheck(t *testing.T) { for _, tc := range testCases { t.Run(tc.title, func(t *testing.T) { resetHealthChecks() - // set warmup for unit testing explicitely to 0 + // set warmup for unit testing explicitly to 0 RegisterHealthCheck(tc.check.name, tc.check, UseWarmup(0)) testRequest(t, handler, tc.expCode, expBody(tc.expBody)) }) @@ -60,6 +60,7 @@ func TestHandlerHealthCheck(t *testing.T) { func TestInitErrorRetryAndCaching(t *testing.T) { handler := HealthHandler() + resetHealthChecks() bgInterval := time.Second @@ -77,7 +78,6 @@ func TestInitErrorRetryAndCaching(t *testing.T) { UseInitErrResultTTL(time.Hour), // Big caching ttl of the init err result ) testRequest(t, handler, http.StatusServiceUnavailable, expBody("ERR: 1 errors and 0 warnings")) - } { @@ -89,6 +89,7 @@ func TestInitErrorRetryAndCaching(t *testing.T) { } // No init err, but expect err because of cache hc.initErr = false + waitForBackgroundCheck(bgInterval) testRequest(t, handler, http.StatusServiceUnavailable, expBody("ERR: 1 errors and 0 warnings")) } @@ -124,6 +125,7 @@ func TestInitErrorRetryAndCaching(t *testing.T) { // Remove init err, no caching, expect OK hc.initErr = false + waitForBackgroundCheck(bgInterval) testRequest(t, handler, http.StatusOK, expBody("OK")) } @@ -133,6 +135,7 @@ func TestInitErrorRetryAndCaching(t *testing.T) { func TestHandlerHealthCheckOptional(t *testing.T) { checkOpt := &mockHealthCheck{name: "TestHandlerHealthCheckErr", healthCheckErr: true} checkReq := &mockHealthCheck{name: "TestOk"} + resetHealthChecks() RegisterHealthCheck(checkReq.name, checkReq) @@ -141,7 +144,7 @@ func TestHandlerHealthCheckOptional(t *testing.T) { testRequest(t, HealthHandler(), http.StatusOK, expBody("OK")) } -// used in testRequest to customise the response body check +// used in testRequest to customise the response body check. type resBodyComparer func(t *testing.T, resBody []byte) // expBody will expect the response body to equal to the passed expected body. @@ -162,11 +165,18 @@ func testRequest(t *testing.T, handler http.Handler, expCode int, expBody resBod rec := httptest.NewRecorder() handler.ServeHTTP(rec, nil) + resp := rec.Result() assert.Equal(t, expCode, resp.StatusCode) - defer resp.Body.Close() + + defer func() { + err := resp.Body.Close() + assert.NoError(t, err) + }() + data, err := io.ReadAll(resp.Body) require.NoError(t, err) + if expBody != nil { expBody(t, data) } @@ -179,10 +189,11 @@ func waitForBackgroundCheck(additionalWait ...time.Duration) { if len(additionalWait) > 0 { t += additionalWait[0] } + time.Sleep(t) } -// remove all previous health checks +// remove all previous health checks. func resetHealthChecks() { requiredChecks = sync.Map{} optionalChecks = sync.Map{} diff --git a/maintenance/health/servicehealthcheck/mocks_test.go b/maintenance/health/servicehealthcheck/mocks_test.go index 682b92732..18467c192 100644 --- a/maintenance/health/servicehealthcheck/mocks_test.go +++ b/maintenance/health/servicehealthcheck/mocks_test.go @@ -21,6 +21,7 @@ func (t *mockHealthCheck) Init(_ context.Context) error { if t.initErr { return errors.New("initError") } + return nil } @@ -28,5 +29,6 @@ func (t *mockHealthCheck) HealthCheck(_ context.Context) HealthCheckResult { if t.healthCheckErr { return HealthCheckResult{State: Err, Msg: "healthCheckErr"} } + return HealthCheckResult{State: Ok} } diff --git a/maintenance/log/handler.go b/maintenance/log/handler.go index 8d320eb16..41a00f5dd 100755 --- a/maintenance/log/handler.go +++ b/maintenance/log/handler.go @@ -9,14 +9,14 @@ import ( "time" "github.com/getsentry/sentry-go" - - "github.com/pace/bricks/maintenance/log/hlog" "github.com/rs/xid" "github.com/rs/zerolog" "github.com/rs/zerolog/log" + + "github.com/pace/bricks/maintenance/log/hlog" ) -// RequestIDHeader name of the header that can contain a request ID +// RequestIDHeader name of the header that can contain a request ID. const RequestIDHeader = "Request-Id" // Handler returns a middleware that handles all of the logging aspects of @@ -40,7 +40,7 @@ func Handler(silentPrefixes ...string) func(http.Handler) http.Handler { } // requestCompleted logs all request related information once -// at the end of the request +// at the end of the request. var requestCompleted = func(r *http.Request, status, size int, duration time.Duration) { ctx := r.Context() @@ -65,7 +65,7 @@ var requestCompleted = func(r *http.Request, status, size int, duration time.Dur Msg("Request Completed") } -// ProxyAwareRemote return the most likely remote address +// ProxyAwareRemote return the most likely remote address. func ProxyAwareRemote(r *http.Request) string { // if we get the content via a proxy, try to extract the // ip from the usual headers @@ -73,10 +73,12 @@ func ProxyAwareRemote(r *http.Request) string { addresses := strings.Split(r.Header.Get(h), ",") for i := len(addresses) - 1; i >= 0; i-- { ip := strings.TrimSpace(addresses[i]) + realIP := net.ParseIP(ip) if !realIP.IsGlobalUnicast() || isPrivate(realIP) { continue // bad address, go to next } + return ip } } @@ -87,12 +89,13 @@ func ProxyAwareRemote(r *http.Request) string { log.Ctx(r.Context()).Warn().Err(err).Msg("failed to decode the remote address") return "" } + return host } // isPrivate reports whether `ip' is a local address, according to // RFC 1918 (IPv4 addresses) and RFC 4193 (IPv6 addresses). -// Remove as soon as https://github.com/golang/go/issues/29146 is resolved +// Remove as soon as https://github.com/golang/go/issues/29146 is resolved. func isPrivate(ip net.IP) bool { if ip4 := ip.To4(); ip4 != nil { // Local IPv4 addresses are defined in https://tools.ietf.org/html/rfc1918 @@ -109,6 +112,7 @@ func RequestIDHandler(fieldKey, headerName string) func(next http.Handler) http. return func(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { ctx := r.Context() + var id xid.ID // try extract of xid from header diff --git a/maintenance/log/handler_test.go b/maintenance/log/handler_test.go index ca87ad6ea..4bfe57303 100644 --- a/maintenance/log/handler_test.go +++ b/maintenance/log/handler_test.go @@ -11,21 +11,24 @@ import ( func TestLoggingHandler(t *testing.T) { rec := httptest.NewRecorder() - req := httptest.NewRequest("GET", "/test", nil) + req := httptest.NewRequest(http.MethodGet, "/test", nil) mux := http.NewServeMux() mux.HandleFunc("/test", func(w http.ResponseWriter, r *http.Request) { if RequestID(r) == "" { t.Error("Request should have request id") } - w.WriteHeader(201) + + w.WriteHeader(http.StatusCreated) }) Handler()(mux).ServeHTTP(rec, req) resp := rec.Result() - defer resp.Body.Close() + defer func() { + _ = resp.Body.Close() + }() - if resp.StatusCode != 201 { + if resp.StatusCode != http.StatusCreated { t.Error("expected 201 status code") } } diff --git a/maintenance/log/hlog/hlog.go b/maintenance/log/hlog/hlog.go index 65f801a4a..5de161894 100644 --- a/maintenance/log/hlog/hlog.go +++ b/maintenance/log/hlog/hlog.go @@ -15,7 +15,7 @@ import ( ) // FromRequest gets the logger in the request's context. -// This is a shortcut for log.Ctx(r.Context()) +// This is a shortcut for log.Ctx(r.Context()). func FromRequest(r *http.Request) *zerolog.Logger { return log.Ctx(r.Context()) } @@ -86,6 +86,7 @@ func RemoteAddrHandler(fieldKey string) func(next http.Handler) http.Handler { return c.Str(fieldKey, host) }) } + next.ServeHTTP(w, r) }) } @@ -102,6 +103,7 @@ func UserAgentHandler(fieldKey string) func(next http.Handler) http.Handler { return c.Str(fieldKey, ua) }) } + next.ServeHTTP(w, r) }) } @@ -118,6 +120,7 @@ func RefererHandler(fieldKey string) func(next http.Handler) http.Handler { return c.Str(fieldKey, ref) }) } + next.ServeHTTP(w, r) }) } @@ -125,7 +128,7 @@ func RefererHandler(fieldKey string) func(next http.Handler) http.Handler { type ( idKey struct{} - traceIdKey struct{} + traceIDKey struct{} ) // IDFromRequest returns the unique id associated to the request if any. @@ -133,6 +136,7 @@ func IDFromRequest(r *http.Request) (id xid.ID, ok bool) { if r == nil { return } + return IDFromCtx(r.Context()) } @@ -144,7 +148,7 @@ func IDFromCtx(ctx context.Context) (id xid.ID, ok bool) { // TraceIDFromCtx returns the trace id associated to the context if any. func TraceIDFromCtx(ctx context.Context) (id string, ok bool) { - id, ok = ctx.Value(traceIdKey{}).(string) + id, ok = ctx.Value(traceIDKey{}).(string) return } @@ -161,21 +165,25 @@ func RequestIDHandler(fieldKey, headerName string) func(next http.Handler) http. return func(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { ctx := r.Context() + id, ok := IDFromRequest(r) if !ok { id = xid.New() ctx = context.WithValue(ctx, idKey{}, id) r = r.WithContext(ctx) } + if fieldKey != "" { log := zerolog.Ctx(ctx) log.UpdateContext(func(c zerolog.Context) zerolog.Context { return c.Str(fieldKey, id.String()) }) } + if headerName != "" { w.Header().Set(headerName, id.String()) } + next.ServeHTTP(w, r) }) } @@ -192,6 +200,7 @@ func CustomHeaderHandler(fieldKey, header string) func(next http.Handler) http.H return c.Str(fieldKey, val) }) } + next.ServeHTTP(w, r) }) } @@ -218,5 +227,6 @@ func ContextTransfer(parentCtx, out context.Context) context.Context { if !found { return out } + return WithValue(out, id) } diff --git a/maintenance/log/log.go b/maintenance/log/log.go index c14e948fd..e110d2e0e 100644 --- a/maintenance/log/log.go +++ b/maintenance/log/log.go @@ -12,12 +12,11 @@ import ( "time" "github.com/caarlos0/env/v10" - "github.com/pace/bricks/maintenance/log/hlog" - + isatty "github.com/mattn/go-isatty" "github.com/rs/zerolog" "github.com/rs/zerolog/log" - isatty "github.com/mattn/go-isatty" + "github.com/pace/bricks/maintenance/log/hlog" ) type config struct { @@ -26,7 +25,7 @@ type config struct { LogCompletedRequest bool `env:"LOG_COMPLETED_REQUEST" envDefault:"true"` } -// map to translate the string log level +// map to translate the string log level. var levelMap = map[string]zerolog.Level{ "debug": zerolog.DebugLevel, "info": zerolog.InfoLevel, @@ -44,8 +43,7 @@ var ( func init() { // parse log config - err := env.Parse(&cfg) - if err != nil { + if err := env.Parse(&cfg); err != nil { Fatalf("Failed to parse server environment: %v", err) } @@ -54,7 +52,9 @@ func init() { if !ok { Fatalf("Unknown log level: %q", cfg.LogLevel) } + zerolog.SetGlobalLevel(v) + log.Logger = log.Logger.Level(v) // auto detect log format @@ -80,16 +80,17 @@ func init() { log.Logger = log.Output(logOutput) } -// RequestID returns a unique request id or an empty string if there is none +// RequestID returns a unique request id or an empty string if there is none. func RequestID(r *http.Request) string { id, ok := hlog.IDFromRequest(r) if ok { return id.String() } + return "" } -// RequestIDFromContext returns a unique request id or an empty string if there is none +// RequestIDFromContext returns a unique request id or an empty string if there is none. func RequestIDFromContext(ctx context.Context) string { id, ok := hlog.IDFromCtx(ctx) if ok { @@ -99,7 +100,7 @@ func RequestIDFromContext(ctx context.Context) string { return "" } -// TraceIDFromContext returns a unique request id or an empty string if there is none +// TraceIDFromContext returns a unique request id or an empty string if there is none. func TraceIDFromContext(ctx context.Context) string { id, ok := hlog.TraceIDFromCtx(ctx) if ok { @@ -109,22 +110,22 @@ func TraceIDFromContext(ctx context.Context) string { return "" } -// Req returns the logger for the passed request +// Req returns the logger for the passed request. func Req(r *http.Request) *zerolog.Logger { return hlog.FromRequest(r) } -// Ctx returns the logger for the passed context +// Ctx returns the logger for the passed context. func Ctx(ctx context.Context) *zerolog.Logger { return log.Ctx(ctx) } -// Logger returns the current logger instance +// Logger returns the current logger instance. func Logger() *zerolog.Logger { return &log.Logger } -// Stack prints the stack of the calling goroutine +// Stack prints the stack of the calling goroutine. func Stack(ctx context.Context) { for _, line := range strings.Split(string(debug.Stack()), "\n") { if line != "" { diff --git a/maintenance/log/log_api.go b/maintenance/log/log_api.go index 0501171ee..71b3e1a24 100644 --- a/maintenance/log/log_api.go +++ b/maintenance/log/log_api.go @@ -6,32 +6,32 @@ import ( "github.com/pace/bricks/maintenance/terminationlog" ) -// Fatal implements log Fatal interface +// Fatal implements log Fatal interface. func Fatal(v ...interface{}) { terminationlog.Fatal(v...) } -// Fatalln implements log Fatalln interface +// Fatalln implements log Fatalln interface. func Fatalln(v ...interface{}) { terminationlog.Fatalln(v...) } -// Fatalf implements log Fatalf interface +// Fatalf implements log Fatalf interface. func Fatalf(format string, v ...interface{}) { terminationlog.Fatalf(format, v...) } -// Print implements log Print interface +// Print implements log Print interface. func Print(v ...interface{}) { Debug(v...) } -// Println implements log Println interface +// Println implements log Println interface. func Println(v ...interface{}) { Debug(v...) } -// Printf implements log Printf interface +// Printf implements log Printf interface. func Printf(format string, v ...interface{}) { Debugf(format, v...) } diff --git a/maintenance/log/log_test.go b/maintenance/log/log_test.go index e4780aeb7..ce934d206 100644 --- a/maintenance/log/log_test.go +++ b/maintenance/log/log_test.go @@ -4,12 +4,13 @@ package log import ( "context" + "net/http" "net/http/httptest" "testing" ) func TestLog(t *testing.T) { - req := httptest.NewRequest("GET", "/", nil) + req := httptest.NewRequest(http.MethodGet, "/", nil) if RequestID(req) != "" { t.Error("Request without set error ID can't have a request id") } diff --git a/maintenance/log/logrus_api.go b/maintenance/log/logrus_api.go index f8f88b5f1..7daaa147c 100644 --- a/maintenance/log/logrus_api.go +++ b/maintenance/log/logrus_api.go @@ -8,26 +8,26 @@ import ( "github.com/rs/zerolog/log" ) -// Error implements logrus Error interface +// Error implements logrus Error interface. func Error(v ...interface{}) { log.Error().Msg(fmt.Sprint(v...)) } -// Warn implements logrus Warn interface +// Warn implements logrus Warn interface. func Warn(v ...interface{}) { log.Warn().Msg(fmt.Sprint(v...)) } -// Info implements logrus Info interface +// Info implements logrus Info interface. func Info(v ...interface{}) { log.Info().Msg(fmt.Sprint(v...)) } -// Debug implements logrus Debug interface +// Debug implements logrus Debug interface. func Debug(v ...interface{}) { log.Debug().Msg(fmt.Sprint(v...)) } -// Errorf implements logrus Errorf interface +// Errorf implements logrus Errorf interface. func Errorf(format string, v ...interface{}) { log.Error().Msg(fmt.Sprintf(format, v...)) } -// Warnf implements logrus Warnf interface +// Warnf implements logrus Warnf interface. func Warnf(format string, v ...interface{}) { log.Warn().Msg(fmt.Sprintf(format, v...)) } -// Infof implements logrus Infof interface +// Infof implements logrus Infof interface. func Infof(format string, v ...interface{}) { log.Info().Msg(fmt.Sprintf(format, v...)) } -// Debugf implements logrus Debugf interface +// Debugf implements logrus Debugf interface. func Debugf(format string, v ...interface{}) { log.Debug().Msg(fmt.Sprintf(format, v...)) } diff --git a/maintenance/log/sink.go b/maintenance/log/sink.go index c47bef166..b59a990a0 100644 --- a/maintenance/log/sink.go +++ b/maintenance/log/sink.go @@ -22,10 +22,11 @@ const defaultSinkSize = 1000 func ContextWithSink(ctx context.Context, sink *Sink) context.Context { l := log.Ctx(ctx).Output(sink) ctx = l.WithContext(ctx) + return context.WithValue(ctx, sinkKey{}, sink) } -// SinkFromContext returns the Sink of the given context if it exists +// SinkFromContext returns the Sink of the given context if it exists. func SinkFromContext(ctx context.Context) (*Sink, bool) { sink, ok := ctx.Value(sinkKey{}).(*Sink) return sink, ok @@ -33,7 +34,7 @@ func SinkFromContext(ctx context.Context) (*Sink, bool) { // SinkContextTransfer gets the sink from the sourceCtx // and returns a new context based on targetCtx with the -// extracted sink. If no sink is present this is a noop +// extracted sink. If no sink is present this is a noop. func SinkContextTransfer(sourceCtx, targetCtx context.Context) context.Context { sink, ok := SinkFromContext(sourceCtx) if !ok { @@ -45,7 +46,7 @@ func SinkContextTransfer(sourceCtx, targetCtx context.Context) context.Context { // Sink respresents a log sink which is used to store // logs, created with log.Ctx(ctx), inside the context -// and use them at a later point in time +// and use them at a later point in time. type Sink struct { Silent bool customSize int @@ -57,7 +58,7 @@ type Sink struct { } // NewSink initializes a new sink. This will deprecate the public properties -// of the sink struct sometime in the future +// of the sink struct sometime in the future. func NewSink(opts ...SinkOption) *Sink { sink := &Sink{} for _, opt := range opts { @@ -68,6 +69,7 @@ func NewSink(opts ...SinkOption) *Sink { if sink.customSize > 0 { sinkSize = sink.customSize } + sink.ring = newStringRing(sinkSize) return sink @@ -84,6 +86,7 @@ func handlerWithSink(silentPrefixes ...string) mux.MiddlewareFunc { return func(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { var sink Sink + for _, prefix := range silentPrefixes { if strings.HasPrefix(r.URL.Path, prefix) { sink.Silent = true @@ -107,7 +110,7 @@ func (s *Sink) ToJSON() []byte { // Pretty returns the logs as string while using the // zerolog.ConsoleWriter to format them in a human -// readable way +// readable way. func (s *Sink) Pretty() string { buf := &bytes.Buffer{} writer := &zerolog.ConsoleWriter{ @@ -118,6 +121,7 @@ func (s *Sink) Pretty() string { s.rwmutex.Lock() defer s.rwmutex.Unlock() + for _, str := range s.ring.GetContent() { n, err := strings.NewReader(str).WriteTo(writer) if err != nil { @@ -155,7 +159,7 @@ func (s *Sink) Write(b []byte) (int, error) { // this is required for cases where a sink is created directly // because then the ring will not be created via newStringRing -// and its size may be 0 (causes div by zero error) +// and its size may be 0 (causes div by zero error). func (s *Sink) initBuffer() { if s.ring.size == 0 { s.ring.size = defaultSinkSize @@ -180,6 +184,7 @@ func (r *stringRing) writeString(c string) { r.data = append(r.data, c) return } + if len(r.data) < r.size-1 { // default case: ring has not reached maximum size yet // so just append and increase @@ -193,7 +198,7 @@ func (r *stringRing) writeString(c string) { } } -// GetContent returns the content of the buffer in the order it was written +// GetContent returns the content of the buffer in the order it was written. func (r *stringRing) GetContent() []string { // default case: write pointer has not started overflowing if len(r.data) < r.size { @@ -201,6 +206,7 @@ func (r *stringRing) GetContent() []string { } else { out := r.data[r.nextPos:] out = append(out, r.data[:r.nextPos]...) + return out } } diff --git a/maintenance/log/sink_test.go b/maintenance/log/sink_test.go index b5432265a..425e2d5fa 100644 --- a/maintenance/log/sink_test.go +++ b/maintenance/log/sink_test.go @@ -12,30 +12,35 @@ import ( func Test_Sink(t *testing.T) { rec := httptest.NewRecorder() - req := httptest.NewRequest("GET", "/test", nil) + req := httptest.NewRequest(http.MethodGet, "/test", nil) var sink *Sink + mux := http.NewServeMux() mux.HandleFunc("/test", func(w http.ResponseWriter, r *http.Request) { require.NotEqual(t, "", RequestID(r), "request should have request id") var ok bool + sink, ok = SinkFromContext(r.Context()) require.True(t, ok, "SinkFromContext() returned false unexpectedly") Req(r).Info().Msg("this is a test message for the sink") - w.WriteHeader(201) + w.WriteHeader(http.StatusCreated) }) Handler()(mux).ServeHTTP(rec, req) resp := rec.Result() - defer resp.Body.Close() + defer func() { + _ = resp.Body.Close() + }() - require.Equal(t, 201, resp.StatusCode, "wrong status code") + require.Equal(t, http.StatusCreated, resp.StatusCode, "wrong status code") logs := sink.ToJSON() var result []interface{} + require.NoError(t, json.Unmarshal(logs, &result), "could not unmarshal logs") require.Len(t, result, 2, "expecting exactly one log, but got %d", len(result)) @@ -46,12 +51,15 @@ func TestOverflowRing(t *testing.T) { for i := 0; i < 2; i++ { ring.writeString(fmt.Sprintf("%02d", i)) } + require.Equal(t, []string{"00", "01"}, ring.data) ring.writeString("02") require.Equal(t, []string{"00", "01", "02"}, ring.data) + for i := 3; i < 5; i++ { ring.writeString(fmt.Sprintf("%02d", i)) } + require.Equal(t, []string{"03", "04", "02"}, ring.data) } @@ -60,11 +68,14 @@ func TestRingGetContent(t *testing.T) { for i := 0; i < 2; i++ { ring.writeString(fmt.Sprintf("%02d", i)) } + require.Equal(t, []string{"00", "01"}, ring.GetContent()) ring.writeString("02") require.Equal(t, []string{"00", "01", "02"}, ring.GetContent()) + for i := 3; i < 5; i++ { ring.writeString(fmt.Sprintf("%02d", i)) } + require.Equal(t, []string{"02", "03", "04"}, ring.GetContent()) } diff --git a/maintenance/metric/handler.go b/maintenance/metric/handler.go index e38528104..953e16edc 100644 --- a/maintenance/metric/handler.go +++ b/maintenance/metric/handler.go @@ -11,7 +11,7 @@ import ( // Handler simply return the prometheus http handler. // The handler will expose all of the collectors and metrics -// that are attached to the prometheus default registry +// that are attached to the prometheus default registry. func Handler() http.Handler { return promhttp.Handler() } diff --git a/maintenance/metric/handler_test.go b/maintenance/metric/handler_test.go index 27d46ae15..b9e3edd06 100644 --- a/maintenance/metric/handler_test.go +++ b/maintenance/metric/handler_test.go @@ -3,20 +3,24 @@ package metric import ( + "net/http" "net/http/httptest" "testing" ) func TestHandler(t *testing.T) { rec := httptest.NewRecorder() - req := httptest.NewRequest("GET", "/metrics", nil) + req := httptest.NewRequest(http.MethodGet, "/metrics", nil) Handler().ServeHTTP(rec, req) resp := rec.Result() - defer resp.Body.Close() - if resp.StatusCode != 200 { + defer func() { + _ = resp.Body.Close() + }() + + if resp.StatusCode != http.StatusOK { t.Errorf("Failed to respond with prometheus metrics: %v", resp.StatusCode) } } diff --git a/maintenance/metric/jsonapi/jsonapi.go b/maintenance/metric/jsonapi/jsonapi.go index 52f8284d2..685f01d41 100644 --- a/maintenance/metric/jsonapi/jsonapi.go +++ b/maintenance/metric/jsonapi/jsonapi.go @@ -101,10 +101,11 @@ func NewMetric(serviceName, path string, w http.ResponseWriter, r *http.Request) // WriteHeader captures the status code for metric submission and // collects the pace_api_http_request_total counter and -// pace_api_http_request_duration_seconds histogram metric +// pace_api_http_request_duration_seconds histogram metric. func (m *Metric) WriteHeader(statusCode int) { clientID, _ := oauth2.ClientID(m.request.Context()) IncPaceAPIHTTPRequestTotal(strconv.Itoa(statusCode), m.request.Method, m.path, m.serviceName, clientID) + duration := float64(time.Since(m.requestStart).Nanoseconds()) / float64(time.Second) AddPaceAPIHTTPRequestDurationSeconds(duration, m.request.Method, m.path, m.serviceName) m.ResponseWriter.WriteHeader(statusCode) @@ -114,10 +115,11 @@ func (m *Metric) WriteHeader(statusCode int) { func (m *Metric) Write(p []byte) (int, error) { size, err := m.ResponseWriter.Write(p) m.sizeWritten += size + return size, err } -// IncPaceAPIHTTPRequestTotal increments the pace_api_http_request_total counter metric +// IncPaceAPIHTTPRequestTotal increments the pace_api_http_request_total counter metric. func IncPaceAPIHTTPRequestTotal(code, method, path, service, clientID string) { paceAPIHTTPRequestTotal.With(prometheus.Labels{ "code": code, @@ -128,7 +130,7 @@ func IncPaceAPIHTTPRequestTotal(code, method, path, service, clientID string) { }).Inc() } -// AddPaceAPIHTTPRequestDurationSeconds adds an observed value for the pace_api_http_request_duration_seconds histogram metric +// AddPaceAPIHTTPRequestDurationSeconds adds an observed value for the pace_api_http_request_duration_seconds histogram metric. func AddPaceAPIHTTPRequestDurationSeconds(duration float64, method, path, service string) { paceAPIHTTPRequestDurationSeconds.With(prometheus.Labels{ "method": method, @@ -147,7 +149,7 @@ func AddPaceAPIHTTPSizeBytes(size float64, method, path, service, requestOrRespo }).Observe(size) } -// lenCallbackReader is a reader that reports the total size before closing +// lenCallbackReader is a reader that reports the total size before closing. type lenCallbackReader struct { reader io.ReadCloser size int @@ -157,6 +159,7 @@ type lenCallbackReader struct { func (r *lenCallbackReader) Read(p []byte) (int, error) { n, err := r.reader.Read(p) r.size += n + return n, err } @@ -165,5 +168,6 @@ func (r *lenCallbackReader) Close() error { n, _ := io.Copy(io.Discard, r.reader) r.size += int(n) r.onEOF(r.size) + return r.reader.Close() } diff --git a/maintenance/metric/jsonapi/jsonapi_test.go b/maintenance/metric/jsonapi/jsonapi_test.go index 1b97b7237..ada6a5d42 100644 --- a/maintenance/metric/jsonapi/jsonapi_test.go +++ b/maintenance/metric/jsonapi/jsonapi_test.go @@ -9,6 +9,8 @@ import ( "strings" "testing" + "github.com/stretchr/testify/assert" + "github.com/pace/bricks/maintenance/metric" ) @@ -16,26 +18,31 @@ func TestMetric(t *testing.T) { t.Run("capture metrics", func(t *testing.T) { t.Run("api request", func(t *testing.T) { rec := httptest.NewRecorder() - req := httptest.NewRequest("GET", "/test/1234567", nil) + req := httptest.NewRequest(http.MethodGet, "/test/1234567", nil) handler := func(w http.ResponseWriter, r *http.Request) { w = NewMetric("simple", "/test/{id}", w, r) - w.WriteHeader(204) + w.WriteHeader(http.StatusNoContent) } handler(rec, req) - req.Body.Close() // that's something the server does + + if err := req.Body.Close(); err != nil { // that's something the server does + panic(err) + } resp := rec.Result() - defer resp.Body.Close() + defer func() { + _ = resp.Body.Close() + }() - if resp.StatusCode != 204 { + if resp.StatusCode != http.StatusNoContent { t.Errorf("Failed to return correct 204 response status, got: %v", resp.StatusCode) } }) t.Run("get metrics request", func(t *testing.T) { rec := httptest.NewRecorder() - req := httptest.NewRequest("GET", "/metrics", nil) + req := httptest.NewRequest(http.MethodGet, "/metrics", nil) metric.Handler().ServeHTTP(rec, req) body := rec.Body.String() @@ -54,22 +61,26 @@ func TestMetric(t *testing.T) { t.Run("capture request size", func(t *testing.T) { t.Run("api request", func(t *testing.T) { rec := httptest.NewRecorder() - req := httptest.NewRequest("POST", "/noop", strings.NewReader("some static request body")) + req := httptest.NewRequest(http.MethodPost, "/noop", strings.NewReader("some static request body")) handler := func(w http.ResponseWriter, r *http.Request) { NewMetric("noop", "/noop", w, r) } handler(rec, req) - req.Body.Close() // that's something the server does + + if err := req.Body.Close(); err != nil { // that's something the server does + panic(err) + } }) t.Run("get metrics request", func(t *testing.T) { rec := httptest.NewRecorder() - req := httptest.NewRequest("GET", "/metrics", nil) + req := httptest.NewRequest(http.MethodGet, "/metrics", nil) metric.Handler().ServeHTTP(rec, req) body := rec.Body.String() wantMetric := `pace_api_http_size_bytes_sum{method="POST",path="/noop",service="noop",type="req"} 24` + if !strings.Contains(body, wantMetric) { t.Errorf("Expected metric %q, got: %v", wantMetric, body) } @@ -80,10 +91,11 @@ func TestMetric(t *testing.T) { t.Run("api request", func(t *testing.T) { rec := httptest.NewRecorder() reqBody := strings.NewReader("some request body") - req := httptest.NewRequest("POST", "/foobar", readerWithoutLen{reqBody}) + req := httptest.NewRequest(http.MethodPost, "/foobar", readerWithoutLen{reqBody}) handler := func(w http.ResponseWriter, r *http.Request) { NewMetric("foobar", "/foobar", w, r) + _, err := io.Copy(io.Discard, r.Body) // read request body if err != nil { panic(err) @@ -91,15 +103,19 @@ func TestMetric(t *testing.T) { } handler(rec, req) - req.Body.Close() // that's something the server does + + if err := req.Body.Close(); err != nil { // that's something the server does + panic(err) + } }) t.Run("get metrics request", func(t *testing.T) { rec := httptest.NewRecorder() - req := httptest.NewRequest("GET", "/metrics", nil) + req := httptest.NewRequest(http.MethodGet, "/metrics", nil) metric.Handler().ServeHTTP(rec, req) body := rec.Body.String() wantMetric := `pace_api_http_size_bytes_sum{method="POST",path="/foobar",service="foobar",type="req"} 17` + if !strings.Contains(body, wantMetric) { t.Errorf("Expected metric %q, got: %v", wantMetric, body) } @@ -110,7 +126,7 @@ func TestMetric(t *testing.T) { t.Run("api request", func(t *testing.T) { rec := httptest.NewRecorder() reqBody := strings.NewReader("some request body that noone ever reads") - req := httptest.NewRequest("POST", "/barfoo", readerWithoutLen{reqBody}) + req := httptest.NewRequest(http.MethodPost, "/barfoo", readerWithoutLen{reqBody}) handler := func(w http.ResponseWriter, r *http.Request) { NewMetric("barfoo", "/barfoo", w, r) @@ -118,15 +134,18 @@ func TestMetric(t *testing.T) { } handler(rec, req) - req.Body.Close() // that's something the server does + + err := req.Body.Close() // that's something the server does + assert.NoError(t, err) }) t.Run("get metrics request", func(t *testing.T) { rec := httptest.NewRecorder() - req := httptest.NewRequest("GET", "/metrics", nil) + req := httptest.NewRequest(http.MethodGet, "/metrics", nil) metric.Handler().ServeHTTP(rec, req) body := rec.Body.String() wantMetric := `pace_api_http_size_bytes_sum{method="POST",path="/barfoo",service="barfoo",type="req"} 39` + if !strings.Contains(body, wantMetric) { t.Errorf("Expected metric %q, got: %v", wantMetric, body) } @@ -136,10 +155,11 @@ func TestMetric(t *testing.T) { t.Run("capture response size", func(t *testing.T) { t.Run("api request", func(t *testing.T) { rec := httptest.NewRecorder() - req := httptest.NewRequest("GET", "/lalala", nil) + req := httptest.NewRequest(http.MethodGet, "/lalala", nil) handler := func(w http.ResponseWriter, r *http.Request) { w = NewMetric("lalala", "/lalala", w, r) + _, err := w.Write([]byte("hehehehe")) if err != nil { panic(err) @@ -147,15 +167,18 @@ func TestMetric(t *testing.T) { } handler(rec, req) - req.Body.Close() // that's something the server does + + err := req.Body.Close() // that's something the server does + assert.NoError(t, err) }) t.Run("get metrics request", func(t *testing.T) { rec := httptest.NewRecorder() - req := httptest.NewRequest("GET", "/metrics", nil) + req := httptest.NewRequest(http.MethodGet, "/metrics", nil) metric.Handler().ServeHTTP(rec, req) body := rec.Body.String() wantMetric := `pace_api_http_size_bytes_sum{method="GET",path="/lalala",service="lalala",type="resp"} 8` + if !strings.Contains(body, wantMetric) { t.Errorf("Expected metric %q, got: %v", wantMetric, body) } @@ -163,7 +186,7 @@ func TestMetric(t *testing.T) { }) } -// readerWithoutLen is a reader that has definitely not a Len() method +// readerWithoutLen is a reader that has definitely not a Len() method. type readerWithoutLen struct { io.Reader } diff --git a/maintenance/terminationlog/termlog.go b/maintenance/terminationlog/termlog.go index 5bfc206a7..58285d911 100644 --- a/maintenance/terminationlog/termlog.go +++ b/maintenance/terminationlog/termlog.go @@ -16,25 +16,25 @@ import ( var logFile *os.File -// Fatalf implements log Fatalf interface +// Fatalf implements log Fatalf interface. func Fatalf(format string, v ...interface{}) { if logFile != nil { - fmt.Fprintf(logFile, format, v...) + _, _ = fmt.Fprintf(logFile, format, v...) } log.Fatal().Msg(fmt.Sprintf(format, v...)) } -// Fatal implements log Fatal interface +// Fatal implements log Fatal interface. func Fatal(v ...interface{}) { if logFile != nil { - fmt.Fprint(logFile, v...) + _, _ = fmt.Fprint(logFile, v...) } log.Fatal().Msg(fmt.Sprint(v...)) } -// Fatalln implements log Fatalln interface +// Fatalln implements log Fatalln interface. func Fatalln(v ...interface{}) { Fatal(v...) } diff --git a/maintenance/terminationlog/termlog_linux_amd64.go b/maintenance/terminationlog/termlog_linux_amd64.go index 352253392..ac95e5648 100644 --- a/maintenance/terminationlog/termlog_linux_amd64.go +++ b/maintenance/terminationlog/termlog_linux_amd64.go @@ -8,20 +8,22 @@ package terminationlog import ( + "log" "os" "syscall" ) -// termLog default location of kubernetes termination log +// termLog default location of kubernetes termination log. const termLog = "/dev/termination-log" func init() { - file, err := os.OpenFile(termLog, os.O_RDWR, 0o666) - + file, err := os.OpenFile(termLog, os.O_RDWR, 0o600) if err == nil { logFile = file // redirect stderr to the termLog - syscall.Dup2(int(logFile.Fd()), 2) // nolint: errcheck + if err := syscall.Dup2(int(logFile.Fd()), 2); err != nil { + log.Fatal(err) + } } } diff --git a/maintenance/terminationlog/termlog_linux_arm64.go b/maintenance/terminationlog/termlog_linux_arm64.go index a4590d876..ae4b465ec 100644 --- a/maintenance/terminationlog/termlog_linux_arm64.go +++ b/maintenance/terminationlog/termlog_linux_arm64.go @@ -12,7 +12,7 @@ import ( "syscall" ) -// termLog default location of kubernetes termination log +// termLog default location of kubernetes termination log. const termLog = "/dev/termination-log" func init() { @@ -22,6 +22,6 @@ func init() { logFile = file // redirect stderr to the termLog - syscall.Dup3(int(logFile.Fd()), 2, 0) // nolint: errcheck + syscall.Dup3(int(logFile.Fd()), 2, 0) } } diff --git a/maintenance/tracing/tracing.go b/maintenance/tracing/tracing.go index ea4ed5e69..d9d781875 100755 --- a/maintenance/tracing/tracing.go +++ b/maintenance/tracing/tracing.go @@ -9,13 +9,14 @@ import ( "strings" "github.com/getsentry/sentry-go" + "github.com/zenazn/goji/web/mutil" + "github.com/pace/bricks/maintenance/log" "github.com/pace/bricks/maintenance/util" - "github.com/zenazn/goji/web/mutil" ) func init() { - var tracesSampleRate float64 = 0.1 + tracesSampleRate := 0.1 val := strings.TrimSpace(os.Getenv("SENTRY_TRACES_SAMPLE_RATE")) if val != "" { @@ -50,7 +51,7 @@ type traceHandler struct { next http.Handler } -// Trace the service function handler execution +// Trace the service function handler execution. func (h *traceHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { ctx := r.Context() hub := sentry.CurrentHub() @@ -83,6 +84,7 @@ func (h *traceHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { }() hub.Scope().SetRequest(r) + r = r.WithContext(transaction.Context()) h.next.ServeHTTP(ww, r) diff --git a/maintenance/tracing/tracing_test.go b/maintenance/tracing/tracing_test.go index da5b091a9..fb306ebb5 100644 --- a/maintenance/tracing/tracing_test.go +++ b/maintenance/tracing/tracing_test.go @@ -7,10 +7,11 @@ import ( "net/http/httptest" "testing" - "github.com/pace/bricks/maintenance/util" + "github.com/gorilla/mux" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" - "github.com/gorilla/mux" + "github.com/pace/bricks/maintenance/util" ) func TestHandlerIgnore(t *testing.T) { @@ -19,7 +20,7 @@ func TestHandlerIgnore(t *testing.T) { r.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) {}) rec := httptest.NewRecorder() - req := httptest.NewRequest("GET", "/test", nil) + req := httptest.NewRequest(http.MethodGet, "/test", nil) // This test does not tests if any prefix is ignored r.ServeHTTP(rec, req) @@ -29,14 +30,20 @@ func TestHandler(t *testing.T) { r := mux.NewRouter() r.Use(Handler()) r.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) { - w.WriteHeader(200) + w.WriteHeader(http.StatusOK) }) rec := httptest.NewRecorder() - req := httptest.NewRequest("GET", "/", nil) + req := httptest.NewRequest(http.MethodGet, "/", nil) r.ServeHTTP(rec, req) + resp := rec.Result() + defer func() { + err := resp.Body.Close() + assert.NoError(t, err) + }() + // This test does not tests the tracing - require.Equal(t, 200, rec.Result().StatusCode) + require.Equal(t, http.StatusOK, resp.StatusCode) } diff --git a/maintenance/util/ignore_prefix_handler.go b/maintenance/util/ignore_prefix_handler.go index a456dc7b4..f60798097 100644 --- a/maintenance/util/ignore_prefix_handler.go +++ b/maintenance/util/ignore_prefix_handler.go @@ -9,17 +9,17 @@ import ( ) // configurableHandler is a wrapper for another middleware. -// It only calls the actual middleware if none of the ignoredPrefixes is prefix of the request path +// It only calls the actual middleware if none of the ignoredPrefixes is prefix of the request path. type configurableHandler struct { ignoredPrefixes []string next http.Handler actualHandler http.Handler } -// ConfigurableMiddlewareOption is a functional option to configure the handler +// ConfigurableMiddlewareOption is a functional option to configure the handler. type ConfigurableMiddlewareOption func(*configurableHandler) error -// WithoutPrefixes allows to configure the ignoredPrefix slice +// WithoutPrefixes allows to configure the ignoredPrefix slice. func WithoutPrefixes(prefix ...string) ConfigurableMiddlewareOption { return func(mdw *configurableHandler) error { mdw.ignoredPrefixes = append(mdw.ignoredPrefixes, prefix...) @@ -36,7 +36,7 @@ func NewIgnorePrefixMiddleware(actualMiddleware func(http.Handler) http.Handler, } // NewConfigurableHandler creates a configurableHandler, that wraps anther handler. -// actualHandler is the handler, that is called if the request is not ignored +// actualHandler is the handler, that is called if the request is not ignored. func NewConfigurableHandler(next, actualHandler http.Handler, cfgs ...ConfigurableMiddlewareOption) *configurableHandler { middleware := &configurableHandler{next: next, actualHandler: actualHandler} for _, cfg := range cfgs { @@ -44,11 +44,12 @@ func NewConfigurableHandler(next, actualHandler http.Handler, cfgs ...Configurab log.Fatal(err) } } + return middleware } // ServeHTTP tests if the path of the current request matches with any prefix of the list of ignored prefixes. -// If the Request should be ignored by the actual handler, the next handler is called, otherwise the actual handler is called +// If the Request should be ignored by the actual handler, the next handler is called, otherwise the actual handler is called. func (m configurableHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { for _, prefix := range m.ignoredPrefixes { if strings.HasPrefix(r.URL.Path, prefix) { @@ -56,5 +57,6 @@ func (m configurableHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { return } } + m.actualHandler.ServeHTTP(w, r) } diff --git a/maintenance/util/ignore_prefix_handler_test.go b/maintenance/util/ignore_prefix_handler_test.go index ee6890c05..b6f0df1af 100644 --- a/maintenance/util/ignore_prefix_handler_test.go +++ b/maintenance/util/ignore_prefix_handler_test.go @@ -35,9 +35,15 @@ func TestMiddlewareWithBlacklist(t *testing.T) { for _, tc := range testCases { t.Run(tc.title, func(t *testing.T) { rec := httptest.NewRecorder() - req := httptest.NewRequest("GET", tc.path, nil) + req := httptest.NewRequest(http.MethodGet, tc.path, nil) r.ServeHTTP(rec, req) + resp := rec.Result() + + defer func() { + _ = resp.Body.Close() + }() + require.Equal(t, tc.statusCodeExpected, resp.StatusCode) }) } diff --git a/pkg/cache/example_test.go b/pkg/cache/example_test.go index bbed062d6..e5c86956e 100644 --- a/pkg/cache/example_test.go +++ b/pkg/cache/example_test.go @@ -27,6 +27,7 @@ func Example_inMemory() { if err != nil { panic(err) } + fmt.Println(string(v)) // forget diff --git a/pkg/cache/memory.go b/pkg/cache/memory.go index 017cd0c9e..c66d718cb 100644 --- a/pkg/cache/memory.go +++ b/pkg/cache/memory.go @@ -36,12 +36,15 @@ func InMemory() *Memory { func (c *Memory) Put(_ context.Context, key string, value []byte, ttl time.Duration) error { v := inMemoryValue{value: make([]byte, len(value))} copy(v.value, value) + if ttl != 0 { v.expiresAt = time.Now().Add(ttl) } + c.mx.Lock() c.values[key] = v c.mx.Unlock() + return nil } @@ -53,9 +56,11 @@ func (c *Memory) Get(ctx context.Context, key string) ([]byte, time.Duration, er c.mx.RLock() v, ok := c.values[key] c.mx.RUnlock() + if !ok { return nil, 0, fmt.Errorf("key %q: %w", key, ErrNotFound) } + var ttl time.Duration if !v.expiresAt.IsZero() { ttl = time.Until(v.expiresAt) @@ -64,8 +69,10 @@ func (c *Memory) Get(ctx context.Context, key string) ([]byte, time.Duration, er return nil, 0, fmt.Errorf("key %q: %w", key, ErrNotFound) } } + value := make([]byte, len(v.value)) copy(value, v.value) + return value, ttl, nil } diff --git a/pkg/cache/memory_test.go b/pkg/cache/memory_test.go index af57a9537..e56ac6798 100644 --- a/pkg/cache/memory_test.go +++ b/pkg/cache/memory_test.go @@ -5,9 +5,10 @@ package cache_test import ( "testing" + "github.com/stretchr/testify/suite" + "github.com/pace/bricks/pkg/cache" "github.com/pace/bricks/pkg/cache/testsuite" - "github.com/stretchr/testify/suite" ) func TestMemory(t *testing.T) { diff --git a/pkg/cache/redis.go b/pkg/cache/redis.go index fe3c96ebd..8daf2d538 100644 --- a/pkg/cache/redis.go +++ b/pkg/cache/redis.go @@ -31,10 +31,10 @@ func InRedis(client *redis.Client, prefix string) *Redis { // is given, the cache automatically forgets the value after the duration. If // ttl is zero then it is never automatically forgotten. func (c *Redis) Put(ctx context.Context, key string, value []byte, ttl time.Duration) error { - err := c.client.Set(ctx, c.prefix+key, value, ttl).Err() - if err != nil { - return fmt.Errorf("%w: redis: %s", ErrBackend, err) + if err := c.client.Set(ctx, c.prefix+key, value, ttl).Err(); err != nil { + return fmt.Errorf("%w: redis: %w", ErrBackend, err) } + return nil } @@ -51,26 +51,32 @@ var redisGETAndPTTL = redis.NewScript(`return { // non-nil. func (c *Redis) Get(ctx context.Context, key string) ([]byte, time.Duration, error) { key = c.prefix + key + r, err := redisGETAndPTTL.Run(ctx, c.client, []string{key}).Result() if err != nil { - return nil, 0, fmt.Errorf("%w: redis: %s", ErrBackend, err) + return nil, 0, fmt.Errorf("%w: redis: %w", ErrBackend, err) } + result, ok := r.([]interface{}) if !ok { return nil, 0, fmt.Errorf("%w: redis returned unexpected type %T, expected %T", ErrBackend, r, result) } + v := result[0] if v == nil { return nil, 0, fmt.Errorf("key %q: %w", key, ErrNotFound) } + value, ok := v.(string) if !ok { return nil, 0, fmt.Errorf("%w: redis returned unexpected type %T, expected %T", ErrBackend, v, value) } + ttl, ok := result[1].(int64) if !ok { return nil, 0, fmt.Errorf("%w: redis returned unexpected type %T, expected %T", ErrBackend, result[1], ttl) } + switch { case ttl == -1: // key exists but has no associated expire return []byte(value), 0, nil @@ -86,9 +92,9 @@ func (c *Redis) Get(ctx context.Context, key string) ([]byte, time.Duration, err // Forget removes the value stored under the key. No error is returned if there // is no value stored. func (c *Redis) Forget(ctx context.Context, key string) error { - err := c.client.Del(ctx, c.prefix+key).Err() - if err != nil { - return fmt.Errorf("%w: redis: %s", ErrBackend, err) + if err := c.client.Del(ctx, c.prefix+key).Err(); err != nil { + return fmt.Errorf("%w: redis: %w", ErrBackend, err) } + return nil } diff --git a/pkg/cache/redis_test.go b/pkg/cache/redis_test.go index a555c6871..e54fc2193 100644 --- a/pkg/cache/redis_test.go +++ b/pkg/cache/redis_test.go @@ -5,16 +5,18 @@ package cache_test import ( "testing" + "github.com/stretchr/testify/suite" + "github.com/pace/bricks/backend/redis" "github.com/pace/bricks/pkg/cache" "github.com/pace/bricks/pkg/cache/testsuite" - "github.com/stretchr/testify/suite" ) func TestIntegrationRedis(t *testing.T) { if testing.Short() { t.SkipNow() } + suite.Run(t, &testsuite.CacheTestSuite{ Cache: cache.InRedis(redis.Client(), "test:cache:"), }) diff --git a/pkg/cache/testsuite/cache.go b/pkg/cache/testsuite/cache.go index 3ca906c1f..1f31c25a5 100644 --- a/pkg/cache/testsuite/cache.go +++ b/pkg/cache/testsuite/cache.go @@ -9,9 +9,10 @@ import ( "sync" "time" + "github.com/stretchr/testify/suite" + "github.com/pace/bricks/maintenance/log" "github.com/pace/bricks/pkg/cache" - "github.com/stretchr/testify/suite" ) type CacheTestSuite struct { @@ -24,24 +25,30 @@ func (suite *CacheTestSuite) TestPut() { ctx := log.WithContext(context.Background()) _ = c.Forget(ctx, "foo") // make sure it doesn't exist + suite.Run("does not error", func() { err := c.Put(ctx, "foo", []byte("bar"), time.Second) suite.NoError(err) }) + _ = c.Forget(ctx, "foo") // clean up _ = c.Forget(ctx, "") // make sure it doesn't exist + suite.Run("accepts all null values", func() { err := c.Put(ctx, "", nil, 0) suite.NoError(err) }) + _ = c.Forget(ctx, "") // clean up _ = c.Forget(ctx, "中文پنجابی🥰🥸") // make sure it doesn't exist + suite.Run("supports unicode", func() { err := c.Put(ctx, "中文پنجابی🥰🥸", []byte("🦤ᐃᓄᒃᑎᑐᑦລາວ"), 0) suite.NoError(err) }) + _ = c.Forget(ctx, "中文پنجابی🥰🥸") // clean up suite.Run("does not error when repeated", func() { @@ -49,6 +56,7 @@ func (suite *CacheTestSuite) TestPut() { err := c.Put(ctx, "foo", []byte("bar"), time.Second) suite.NoError(err) }) + _ = c.Forget(ctx, "foo") // clean up suite.Run("stores a value", func() { @@ -56,6 +64,7 @@ func (suite *CacheTestSuite) TestPut() { value, _, _ := c.Get(ctx, "foo") suite.Equal([]byte("bar"), value) }) + _ = c.Forget(ctx, "foo") // clean up suite.Run("is unaffected from manipulating the input", func() { @@ -65,15 +74,18 @@ func (suite *CacheTestSuite) TestPut() { value, _, _ := c.Get(ctx, "foo") suite.Equal([]byte("bar"), value) }) + _ = c.Forget(ctx, "foo") // clean up for i := 0; i <= 5; i++ { // make sure it doesn't exist _ = c.Forget(ctx, fmt.Sprintf("foo%d", i)) } + suite.Run("does not error on simultaneous use", func() { var wg sync.WaitGroup for i := 0; i <= 5; i++ { wg.Add(1) + go func() { err := c.Put(ctx, fmt.Sprintf("foo%d", i), []byte("bar"), 0) suite.NoError(err) @@ -82,6 +94,7 @@ func (suite *CacheTestSuite) TestPut() { wg.Wait() } }) + for i := 0; i <= 5; i++ { // clean up _ = c.Forget(ctx, fmt.Sprintf("foo%d", i)) } @@ -92,6 +105,7 @@ func (suite *CacheTestSuite) TestGet() { ctx := log.WithContext(context.Background()) _ = c.Forget(ctx, "foo") // make sure it doesn't exist + suite.Run("returns the ttl if set", func() { _ = c.Put(ctx, "foo", []byte("bar"), time.Minute) _, ttl, _ := c.Get(ctx, "foo") @@ -99,6 +113,7 @@ func (suite *CacheTestSuite) TestGet() { suite.LessOrEqual(int64(ttl), int64(time.Minute)) suite.Greater(int64(ttl), int64(time.Minute-time.Second)) }) + _ = c.Forget(ctx, "foo") // clean up suite.Run("returns 0 as ttl if ttl not set", func() { @@ -106,33 +121,41 @@ func (suite *CacheTestSuite) TestGet() { _, ttl, _ := c.Get(ctx, "foo") suite.Equal(time.Duration(0), ttl) }) + _ = c.Forget(ctx, "foo") // clean up suite.Run("returns not found error", func() { _, _, err := c.Get(ctx, "foo") suite.True(errors.Is(err, cache.ErrNotFound)) }) + _ = c.Forget(ctx, "foo") // clean up suite.Run("returns not found if ttl ran out", func() { err := c.Put(ctx, "foo", []byte("bar"), time.Millisecond) // minimum ttl suite.NoError(err) + <-time.After(2 * time.Millisecond) + _, _, err = c.Get(ctx, "foo") suite.True(errors.Is(err, cache.ErrNotFound)) }) + _ = c.Forget(ctx, "foo") // clean up _ = c.Forget(ctx, "foo1") // make sure it doesn't exist _ = c.Forget(ctx, "foo2") // make sure it doesn't exist + suite.Run("retrieves the right value", func() { _ = c.Put(ctx, "foo1", []byte("bar1"), 0) _ = c.Put(ctx, "foo2", []byte("bar2"), 0) value1, _, _ := c.Get(ctx, "foo1") value2, _, _ := c.Get(ctx, "foo2") + suite.Equal([]byte("bar1"), value1) suite.Equal([]byte("bar2"), value2) }) + _ = c.Forget(ctx, "foo1") // clean up _ = c.Forget(ctx, "foo2") // clean up @@ -143,6 +166,7 @@ func (suite *CacheTestSuite) TestGet() { value, _, _ := c.Get(ctx, "foo") suite.Equal([]byte("bar"), value) }) + _ = c.Forget(ctx, "foo") // clean up suite.Run("does not produce nil", func() { @@ -150,34 +174,42 @@ func (suite *CacheTestSuite) TestGet() { value, _, _ := c.Get(ctx, "foo") suite.NotNil(value) }) + _ = c.Forget(ctx, "foo") // clean up _ = c.Forget(ctx, "") // make sure it doesn't exist + suite.Run("returns value stored with an empty key", func() { _ = c.Put(ctx, "", []byte("bar"), 0) value, _, _ := c.Get(ctx, "") suite.Equal([]byte("bar"), value) }) + _ = c.Forget(ctx, "") // clean up _ = c.Forget(ctx, "中文پنجابی🥰🥸") // make sure it doesn't exist + suite.Run("supports unicode", func() { _ = c.Put(ctx, "中文پنجابی🥰🥸", []byte("🦤ᐃᓄᒃᑎᑐᑦລາວ\x00"), 0) value, _, _ := c.Get(ctx, "中文پنجابی🥰🥸") suite.Equal([]byte("🦤ᐃᓄᒃᑎᑐᑦລາວ\x00"), value) }) + _ = c.Forget(ctx, "中文پنجابی🥰🥸") // clean up for i := 0; i <= 5; i++ { // make sure it doesn't exist _ = c.Forget(ctx, fmt.Sprintf("foo%d", i)) } + suite.Run("does not error on simultaneous use", func() { for i := 0; i <= 5; i++ { _ = c.Put(ctx, fmt.Sprintf("foo%d", i), []byte("bar"), 0) } + var wg sync.WaitGroup for i := 0; i <= 5; i++ { wg.Add(1) + go func() { _, _, err := c.Get(ctx, fmt.Sprintf("foo%d", i)) suite.NoError(err) @@ -186,6 +218,7 @@ func (suite *CacheTestSuite) TestGet() { wg.Wait() } }) + for i := 0; i <= 5; i++ { // clean up _ = c.Forget(ctx, fmt.Sprintf("foo%d", i)) } @@ -196,12 +229,14 @@ func (suite *CacheTestSuite) TestForget() { ctx := log.WithContext(context.Background()) _ = c.Forget(ctx, "foo") // make sure it doesn't exist + suite.Run("works", func() { _ = c.Put(ctx, "foo", []byte("bar"), 0) _ = c.Forget(ctx, "foo") _, _, err := c.Get(ctx, "foo") suite.True(errors.Is(err, cache.ErrNotFound)) }) + _ = c.Forget(ctx, "foo") // clean up suite.Run("does not error when repeated", func() { @@ -209,25 +244,31 @@ func (suite *CacheTestSuite) TestForget() { err := c.Forget(ctx, "foo") suite.NoError(err) }) + _ = c.Forget(ctx, "foo") // clean up _ = c.Forget(ctx, "中文پنجابی🥰🥸") // make sure it doesn't exist + suite.Run("supports unicode", func() { err := c.Forget(ctx, "中文پنجابی🥰🥸") suite.NoError(err) }) + _ = c.Forget(ctx, "中文پنجابی🥰🥸") // clean up for i := 0; i <= 5; i++ { // make sure it doesn't exist _ = c.Forget(ctx, fmt.Sprintf("foo%d", i)) } + suite.Run("does not error on simultaneous use", func() { for i := 0; i <= 5; i++ { _ = c.Put(ctx, fmt.Sprintf("foo%d", i), []byte("bar"), 0) } + var wg sync.WaitGroup for i := 0; i <= 5; i++ { wg.Add(1) + go func() { err := c.Forget(ctx, fmt.Sprintf("foo%d", i)) suite.NoError(err) @@ -236,6 +277,7 @@ func (suite *CacheTestSuite) TestForget() { wg.Wait() } }) + for i := 0; i <= 5; i++ { // clean up _ = c.Forget(ctx, fmt.Sprintf("foo%d", i)) } diff --git a/pkg/cache/testsuite/cache_test.go b/pkg/cache/testsuite/cache_test.go index b58fcdd20..f5826caac 100644 --- a/pkg/cache/testsuite/cache_test.go +++ b/pkg/cache/testsuite/cache_test.go @@ -5,9 +5,10 @@ package testsuite_test import ( "testing" + "github.com/stretchr/testify/suite" + "github.com/pace/bricks/pkg/cache" . "github.com/pace/bricks/pkg/cache/testsuite" - "github.com/stretchr/testify/suite" ) // TestStringsTestSuite tests the reference in-memory cache implementation. diff --git a/pkg/context/transfer.go b/pkg/context/transfer.go index d33e3332c..0a3d72174 100755 --- a/pkg/context/transfer.go +++ b/pkg/context/transfer.go @@ -4,6 +4,7 @@ import ( "context" "github.com/getsentry/sentry-go" + http "github.com/pace/bricks/http/middleware" "github.com/pace/bricks/http/oauth2" "github.com/pace/bricks/locale" @@ -49,5 +50,6 @@ func TransferExternalDependencyContext(in, out context.Context) context.Context if edc == nil { return out } + return http.ContextWithExternalDependency(out, edc) } diff --git a/pkg/isotime/isotime.go b/pkg/isotime/isotime.go index bd831e0a1..ca6f0e5c6 100644 --- a/pkg/isotime/isotime.go +++ b/pkg/isotime/isotime.go @@ -23,6 +23,7 @@ func ParseISO8601(str string) (time.Time, error) { } var t time.Time + var err error for _, l := range iso8601Layouts { diff --git a/pkg/isotime/isotime_test.go b/pkg/isotime/isotime_test.go index 310e694a8..2a74d7d69 100644 --- a/pkg/isotime/isotime_test.go +++ b/pkg/isotime/isotime_test.go @@ -87,6 +87,7 @@ func TestParseISO8601(t *testing.T) { t.Errorf("ParseISO8601() error = %v, wantErr %v", err, tt.wantErr) return } + if !got.Equal(tt.want) { t.Errorf("ParseISO8601() = %v, want %v", got, tt.want) } diff --git a/pkg/lock/redis/lock.go b/pkg/lock/redis/lock.go index 6fbe55978..022b7a091 100644 --- a/pkg/lock/redis/lock.go +++ b/pkg/lock/redis/lock.go @@ -10,12 +10,12 @@ import ( "sync" "time" - redisbackend "github.com/pace/bricks/backend/redis" - pberrors "github.com/pace/bricks/maintenance/errors" - "github.com/bsm/redislock" "github.com/redis/go-redis/v9" "github.com/rs/zerolog/log" + + redisbackend "github.com/pace/bricks/backend/redis" + pberrors "github.com/pace/bricks/maintenance/errors" ) var ( @@ -43,6 +43,7 @@ type LockOption func(l *Lock) func NewLock(name string, opts ...LockOption) *Lock { initClient() + l := &Lock{Name: name} for _, opt := range []LockOption{ // default options SetTTL(5 * time.Second), @@ -50,9 +51,11 @@ func NewLock(name string, opts ...LockOption) *Lock { } { opt(l) } + for _, opt := range opts { opt(l) } + return l } @@ -67,6 +70,7 @@ func (l *Lock) Acquire(ctx context.Context) (bool, error) { lock, err := l.locker.Obtain(ctx, l.Name, l.lockTTL, opts) if err != nil { log.Ctx(ctx).Debug().Err(err).Str("lockName", l.Name).Msg("Could not acquire lock") + switch { case errors.Is(err, redislock.ErrNotObtained): return false, nil @@ -76,6 +80,7 @@ func (l *Lock) Acquire(ctx context.Context) (bool, error) { } l.lock = lock + return true, nil } @@ -94,6 +99,7 @@ func (l *Lock) AcquireWait(ctx context.Context) error { } l.lock = lock + return nil } @@ -122,8 +128,9 @@ func (l *Lock) AcquireAndKeepUp(ctx context.Context) (context.Context, context.C defer cancelLock() keepUpLock(lockCtx, lock, l.lockTTL) + err := lock.Release(ctx) - if err != nil && err != redislock.ErrLockNotHeld { + if err != nil && !errors.Is(err, redislock.ErrLockNotHeld) { log.Ctx(lockCtx).Debug().Err(err).Msgf("could not release lock %q", l.Name) } }() @@ -136,6 +143,7 @@ func (l *Lock) AcquireAndKeepUp(ctx context.Context) (context.Context, context.C func keepUpLock(ctx context.Context, lock *redislock.Lock, refreshTTL time.Duration) { refreshInterval := refreshTTL / 5 lockRunsOutIn := refreshTTL // initial value after obtaining the lock + for { select { case <-ctx.Done(): @@ -149,13 +157,15 @@ func keepUpLock(ctx context.Context, lock *redislock.Lock, refreshTTL time.Durat // Try to refresh lock. case <-time.After(refreshInterval): } - if err := lock.Refresh(ctx, refreshTTL, nil); err == redislock.ErrNotObtained { + + if err := lock.Refresh(ctx, refreshTTL, nil); errors.Is(err, redislock.ErrNotObtained) { // Don't return just yet. Get the TTL of the lock and try to // refresh for as long as the TTL is not over. if lockRunsOutIn, err = lock.TTL(ctx); err != nil { log.Ctx(ctx).Debug().Err(err).Msg("could not get ttl of lock") return // assuming we lost the lock } + continue } else if err != nil { log.Ctx(ctx).Debug().Err(err).Msg("could not refresh lock") @@ -177,6 +187,7 @@ func (l *Lock) Release(ctx context.Context) error { if err := l.lock.Release(ctx); err != nil { log.Ctx(ctx).Debug().Err(err).Msg("error releasing redis lock") + switch { case errors.Is(err, redislock.ErrLockNotHeld): // well, since our only goal is that the lock is released, this will suffice @@ -186,6 +197,7 @@ func (l *Lock) Release(ctx context.Context) error { } l.lock = nil + return nil } diff --git a/pkg/lock/redis/lock_test.go b/pkg/lock/redis/lock_test.go index 93c3ca24c..803e31665 100644 --- a/pkg/lock/redis/lock_test.go +++ b/pkg/lock/redis/lock_test.go @@ -29,14 +29,18 @@ func TestIntegration_RedisLock(t *testing.T) { for try := 0; true; try++ { lockCtx, releaseLock, err = lock.AcquireAndKeepUp(ctx) require.NoError(t, err) + if lockCtx == nil { t.Log("Not obtained, try again in 1sec") time.Sleep(time.Second) + continue } + require.NotNil(t, lockCtx) require.NotNil(t, releaseLock) releaseLock() + break } } diff --git a/pkg/redact/context.go b/pkg/redact/context.go index 3c6d5d97a..951125727 100644 --- a/pkg/redact/context.go +++ b/pkg/redact/context.go @@ -6,17 +6,18 @@ import "context" type patternRedactorKey struct{} -// WithContext allows storing the PatternRedactor inside a context for passing it on +// WithContext allows storing the PatternRedactor inside a context for passing it on. func (r *PatternRedactor) WithContext(ctx context.Context) context.Context { return context.WithValue(ctx, patternRedactorKey{}, r) } // Ctx returns the PatternRedactor stored within the context. If no redactor -// has been defined, an empty redactor is returned that does nothing +// has been defined, an empty redactor is returned that does nothing. func Ctx(ctx context.Context) *PatternRedactor { if rd, ok := ctx.Value(patternRedactorKey{}).(*PatternRedactor); ok { return rd.Clone() } + return NewPatternRedactor(RedactionSchemeDoNothing()) } @@ -25,5 +26,6 @@ func ContextTransfer(ctx, targetCtx context.Context) context.Context { if redactor := Ctx(ctx); redactor != nil { return context.WithValue(targetCtx, patternRedactorKey{}, redactor) } + return targetCtx } diff --git a/pkg/redact/default.go b/pkg/redact/default.go index e731e0914..6d302f439 100644 --- a/pkg/redact/default.go +++ b/pkg/redact/default.go @@ -2,7 +2,7 @@ package redact -// redactionSafe last 4 digits are usually concidered safe (e.g. credit cards, iban, ...) +// redactionSafe last 4 digits are usually considered safe (e.g. credit cards, iban, ...) const redactionSafe = 4 var Default *PatternRedactor diff --git a/pkg/redact/middleware/middleware.go b/pkg/redact/middleware/middleware.go index d4942a701..e1b7ba650 100644 --- a/pkg/redact/middleware/middleware.go +++ b/pkg/redact/middleware/middleware.go @@ -8,13 +8,13 @@ import ( "github.com/pace/bricks/pkg/redact" ) -// Redact provides a pattern redactor middleware to the request context +// Redact provides a pattern redactor middleware to the request context. func Redact(next http.Handler) http.Handler { return RedactWithScheme(next, redact.Default) } // RedactWithScheme provides a pattern redactor middleware to the request context -// using the provided scheme +// using the provided scheme. func RedactWithScheme(next http.Handler, redactor *redact.PatternRedactor) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { ctx := redactor.WithContext(r.Context()) diff --git a/pkg/redact/pattern.go b/pkg/redact/pattern.go index 831935fd5..1b29b4cd1 100644 --- a/pkg/redact/pattern.go +++ b/pkg/redact/pattern.go @@ -7,7 +7,7 @@ import "regexp" // Sources: // CreditCard: https://www.regular-expressions.info/creditcard.html -// AllPatterns is a list of all default redaction patterns +// AllPatterns is a list of all default redaction patterns. var AllPatterns = []*regexp.Regexp{ PatternIBAN, PatternJWT, @@ -47,9 +47,9 @@ var ( // JCB cards beginning with 2131 or 1800 have 15 digits. JCB cards beginning with 35 have 16 digits. PatternCCJCB = regexp.MustCompile(`(?:2131|1800|35\d{3})\d{11}`) - // PatternJWT JsonWebToken + // PatternJWT JsonWebToken. PatternJWT = regexp.MustCompile(`(?:ey[a-zA-Z0-9=_-]+\.){2}[a-zA-Z0-9=_-]+`) - // PatternBasicAuthBase match any: Basic YW55IGNhcm5hbCBwbGVhcw== does not validate base64 string + // PatternBasicAuthBase match any: Basic YW55IGNhcm5hbCBwbGVhcw== does not validate base64 string. PatternBasicAuthBase64 = regexp.MustCompile(`Authorization: Basic ([a-zA-Z0-9=]*)`) ) diff --git a/pkg/redact/redact.go b/pkg/redact/redact.go index 797655aed..e08fbf3d8 100644 --- a/pkg/redact/redact.go +++ b/pkg/redact/redact.go @@ -11,7 +11,7 @@ type PatternRedactor struct { scheme RedactionScheme } -// NewPatternRedactor creates a new redactor for masking certain patterns +// NewPatternRedactor creates a new redactor for masking certain patterns. func NewPatternRedactor(scheme RedactionScheme) *PatternRedactor { return &PatternRedactor{ scheme: scheme, @@ -23,25 +23,29 @@ func (r *PatternRedactor) Mask(data string) string { if pattern == nil { continue } + data = pattern.ReplaceAllStringFunc(data, r.scheme) } + return data } -// AddPattern adds patterns to the redactor +// AddPatterns adds patterns to the redactor. func (r *PatternRedactor) AddPatterns(patterns ...*regexp.Regexp) { r.patterns = append(r.patterns, patterns...) } -// RemovePattern deletes a pattern from the redactor +// RemovePattern deletes a pattern from the redactor. func (r *PatternRedactor) RemovePattern(pattern *regexp.Regexp) { index := -1 + for i, p := range r.patterns { if p == pattern || p.String() == pattern.String() { index = i break } } + if index >= 0 { r.patterns = append(r.patterns[:index], r.patterns[index+1:]...) } @@ -55,5 +59,6 @@ func (r *PatternRedactor) Clone() *PatternRedactor { rc := NewPatternRedactor(r.scheme) rc.patterns = make([]*regexp.Regexp, len(r.patterns)) copy(rc.patterns, r.patterns) + return rc } diff --git a/pkg/redact/redact_test.go b/pkg/redact/redact_test.go index df9106814..69e1d5f31 100644 --- a/pkg/redact/redact_test.go +++ b/pkg/redact/redact_test.go @@ -6,9 +6,9 @@ import ( "regexp" "testing" - "github.com/pace/bricks/pkg/redact" - "github.com/stretchr/testify/assert" + + "github.com/pace/bricks/pkg/redact" ) func TestRedactionSchemeKeepLast(t *testing.T) { @@ -30,6 +30,7 @@ and a ********************ring, as well as ****************cret` res := redactor.Mask(originalString) assert.Equal(t, expectedString1, res) redactor.RemovePattern(regexp.MustCompile("DE12345678909876543210")) + res = redactor.Mask(originalString) assert.Equal(t, expectedString2, res) } diff --git a/pkg/redact/scheme.go b/pkg/redact/scheme.go index 960acd799..90f6692da 100644 --- a/pkg/redact/scheme.go +++ b/pkg/redact/scheme.go @@ -7,7 +7,7 @@ import "strings" type RedactionScheme func(string) string // RedactionSchemeDoNothing doesn't redact any values -// Note: only use for testing +// Note: only use for testing. func RedactionSchemeDoNothing() func(string) string { return func(old string) string { return old @@ -15,19 +15,20 @@ func RedactionSchemeDoNothing() func(string) string { } // RedactionSchemeKeepLast replaces all runes in the string with an asterisk -// except the last NUM runes +// except the last NUM runes. func RedactionSchemeKeepLast(num int) func(string) string { return func(old string) string { runes := []rune(old) for i := 0; i < len(runes)-num; i++ { runes[i] = '*' } + return string(runes) } } -// RedactionSchemeKeepLast replaces all runes in the string with an asterisk -// except the last NUM runes +// RedactionSchemeKeepLastJWTNoSignature replaces all runes in the string with an asterisk +// except the last NUM runes. func RedactionSchemeKeepLastJWTNoSignature(num int) func(string) string { defaultScheme := RedactionSchemeKeepLast(num) @@ -35,6 +36,7 @@ func RedactionSchemeKeepLastJWTNoSignature(num int) func(string) string { if PatternJWT.Match([]byte(s)) { parts := strings.Split(s, ".") parts[2] = defaultScheme(parts[2]) + return strings.Join(parts, ".") } diff --git a/pkg/routine/backoff.go b/pkg/routine/backoff.go index 3b7baf3e0..98f9d4361 100644 --- a/pkg/routine/backoff.go +++ b/pkg/routine/backoff.go @@ -28,5 +28,6 @@ func (all combinedExponentialBackoff) Duration(key interface{}) (dur time.Durati backoff.Reset() } } + return } diff --git a/pkg/routine/cluster_background_task_test.go b/pkg/routine/cluster_background_task_test.go index f89bc8352..d0b16995d 100644 --- a/pkg/routine/cluster_background_task_test.go +++ b/pkg/routine/cluster_background_task_test.go @@ -15,8 +15,9 @@ import ( "testing" "time" - "github.com/pace/bricks/pkg/routine" "github.com/stretchr/testify/assert" + + "github.com/pace/bricks/pkg/routine" ) func Example_clusterBackgroundTask() { @@ -40,6 +41,7 @@ func Example_clusterBackgroundTask() { default: } out <- fmt.Sprintf("task run %d", i) + time.Sleep(100 * time.Millisecond) } }, @@ -56,6 +58,7 @@ func Example_clusterBackgroundTask() { for i := 0; i < 3; i++ { println(<-out) } + cancel() // Output: @@ -83,11 +86,13 @@ func TestIntegrationRunNamed_clusterBackgroundTask(t *testing.T) { var wg sync.WaitGroup for i := 0; i < 2; i++ { wg.Add(1) + go func() { spawnProcess(&buf) wg.Done() }() } + wg.Wait() // until both processes are done exp := `task run 0 @@ -101,18 +106,19 @@ task run 2 } func spawnProcess(w io.Writer) { - cmd := exec.Command(os.Args[0], + cmd := exec.Command(os.Args[0], //nolint:gosec "-test.timeout=2s", "-test.run=Example_clusterBackgroundTask", ) + cmd.Env = append(os.Environ(), "TEST_SUBPROCESS=1", "ROUTINE_REDIS_LOCK_TTL=200ms", ) cmd.Stdout = w cmd.Stderr = w - err := cmd.Run() - if err != nil { + + if err := cmd.Run(); err != nil { _, _ = w.Write([]byte("error starting subprocess: " + err.Error())) } } @@ -134,12 +140,14 @@ func (b *subprocessOutputBuffer) Write(p []byte) (int, error) { strings.Contains(s, "Redis connection pool created"): return len(p), nil } + return b.buf.Write(p) } func (b *subprocessOutputBuffer) String() string { b.mx.Lock() defer b.mx.Unlock() + return b.buf.String() } @@ -151,5 +159,6 @@ func println(s string) { // go around the test runner _, _ = log.Writer().Write([]byte(s + "\n")) } + fmt.Println(s) } diff --git a/pkg/routine/instance.go b/pkg/routine/instance.go index 01c6911f4..8c38b52a3 100755 --- a/pkg/routine/instance.go +++ b/pkg/routine/instance.go @@ -8,10 +8,10 @@ import ( "time" "github.com/getsentry/sentry-go" + exponential "github.com/jpillora/backoff" + "github.com/pace/bricks/maintenance/errors" "github.com/pace/bricks/pkg/lock/redis" - - exponential "github.com/jpillora/backoff" ) type routineThatKeepsRunningOneInstance struct { @@ -38,8 +38,10 @@ func (r *routineThatKeepsRunningOneInstance) Run(ctx context.Context) { "routine": &exponential.Backoff{Min: r.retryInterval, Max: 10 * time.Minute}, } - r.num = ctx.Value(ctxNumKey{}).(int64) + r.num, _ = ctx.Value(ctxNumKey{}).(int64) + var tryAgainIn time.Duration // zero on first run + for { select { case <-ctx.Done(): @@ -50,6 +52,7 @@ func (r *routineThatKeepsRunningOneInstance) Run(ctx context.Context) { // after the routine returned. singleRunCtx, cancel := context.WithCancel(ctx) tryAgainIn = r.singleRun(singleRunCtx) + cancel() } } @@ -59,14 +62,18 @@ func (r *routineThatKeepsRunningOneInstance) Run(ctx context.Context) { // should be performed. func (r *routineThatKeepsRunningOneInstance) singleRun(ctx context.Context) time.Duration { l := redis.NewLock("routine:lock:"+r.Name, redis.SetTTL(r.lockTTL)) + lockCtx, cancel, err := l.AcquireAndKeepUp(ctx) if err != nil { go errors.Handle(ctx, err) // report error to Sentry, non-blocking return r.backoff.Duration("lock") } + if lockCtx != nil { defer cancel() + routinePanicked := true + func() { defer errors.HandleWithCtx(ctx, fmt.Sprintf("routine %d", r.num)) // handle panics @@ -74,12 +81,16 @@ func (r *routineThatKeepsRunningOneInstance) singleRun(ctx context.Context) time defer span.Finish() r.Routine(span.Context()) + routinePanicked = false }() + if routinePanicked { return r.backoff.Duration("routine") } } + r.backoff.ResetAll() + return r.retryInterval } diff --git a/pkg/routine/routine.go b/pkg/routine/routine.go index e8614534f..c1d910915 100755 --- a/pkg/routine/routine.go +++ b/pkg/routine/routine.go @@ -13,6 +13,7 @@ import ( "syscall" "github.com/getsentry/sentry-go" + "github.com/pace/bricks/maintenance/errors" "github.com/pace/bricks/maintenance/log" ) @@ -119,7 +120,8 @@ func Run(ctx context.Context, routine func(context.Context)) (cancel context.Can routine(span.Context()) }() - return + + return //nolint:nakedret } type ctxNumKey struct{} @@ -136,6 +138,7 @@ var ( func init() { c := make(chan os.Signal, 1) signal.Notify(c, syscall.SIGINT, syscall.SIGTERM) + go func() { <-c // block until SIGINT/SIGTERM is received signal.Stop(c) @@ -147,6 +150,7 @@ func init() { Int("count", len(contexts)). Ints64("routines", routineNumbers()). Msg("received shutdown signal, canceling all running routines") + for _, cancel := range contexts { cancel() } @@ -158,5 +162,6 @@ func routineNumbers() []int64 { for num := range contexts { routines = append(routines, num) } + return routines } diff --git a/pkg/routine/routine_test.go b/pkg/routine/routine_test.go index 788aef69a..538138665 100644 --- a/pkg/routine/routine_test.go +++ b/pkg/routine/routine_test.go @@ -45,6 +45,7 @@ func TestRun_transfersLogger(t *testing.T) { func TestRun_transfersSink(t *testing.T) { var sink log.Sink + logger := log.Logger() ctx := log.ContextWithSink(logger.WithContext(context.Background()), &sink) waitForRun(ctx, func(ctx context.Context) { @@ -75,6 +76,7 @@ func TestRun_transfersOAuth2Token(t *testing.T) { func TestRun_cancelsContextAfterRoutineIsFinished(t *testing.T) { routineCtx := contextAfterRun(context.Background(), nil) + require.Eventually(t, func() bool { return routineCtx.Err() == context.Canceled }, time.Second, time.Millisecond) @@ -86,12 +88,15 @@ func TestRun_blocksAfterShutdown(t *testing.T) { func testRunBlocksAfterShutdown(t *testing.T) { var endOfTest sync.WaitGroup + endOfTest.Add(1) // start routine that gets canceled by the shutdown routineCtx := make(chan context.Context) + Run(context.Background(), func(ctx context.Context) { routineCtx <- ctx + endOfTest.Wait() }) @@ -134,19 +139,24 @@ func TestRun_cancelsContextsOnSIGTERM(t *testing.T) { func testRunCancelsContextsOn(t *testing.T, signum syscall.Signal) { var endOfTest, routinesStarted sync.WaitGroup + endOfTest.Add(1) // start a few routines routineContexts := [3]context.Context{} routinesStarted.Add(len(routineContexts)) + for i := range routineContexts { i := i + Run(context.Background(), func(ctx context.Context) { routineContexts[i] = ctx + routinesStarted.Done() endOfTest.Wait() }) } + routinesStarted.Wait() // kill this process @@ -170,7 +180,8 @@ func exitAfterTest(t *testing.T, name string, testFunc func(*testing.T)) { testFunc(t) os.Exit(0) } - cmd := exec.Command(os.Args[0], "-test.run="+name) + + cmd := exec.Command(os.Args[0], "-test.run="+name) //nolint:gosec cmd.Env = append(os.Environ(), "ROUTINE_EXIT_AFTER_TEST=1") require.NoError(t, cmd.Run()) } @@ -178,6 +189,7 @@ func exitAfterTest(t *testing.T, name string, testFunc func(*testing.T)) { // Calls Run and returns once the routine is finished. func waitForRun(ctx context.Context, routine func(context.Context)) { done := make(chan struct{}) + Run(ctx, func(ctx context.Context) { defer func() { done <- struct{}{} }() routine(ctx) @@ -189,12 +201,15 @@ func waitForRun(ctx context.Context, routine func(context.Context)) { // routine is finished. func contextAfterRun(ctx context.Context, routine func(context.Context)) context.Context { var routineCtx context.Context + waitForRun(ctx, func(ctx context.Context) { if routine != nil { routine(ctx) } + routineCtx = ctx }) + return routineCtx } diff --git a/pkg/synctx/wg.go b/pkg/synctx/wg.go index 44f9d5faf..760eb0eec 100644 --- a/pkg/synctx/wg.go +++ b/pkg/synctx/wg.go @@ -6,14 +6,15 @@ import ( "sync" ) -// WaitGroup extended with Finish func +// WaitGroup extended with Finish func. type WaitGroup struct { sync.WaitGroup } -// Finish allows to be used easily with go contexts +// Finish allows to be used easily with go contexts. func (wg *WaitGroup) Finish() <-chan struct{} { ch := make(chan struct{}) go func() { wg.Wait(); close(ch) }() + return ch } diff --git a/pkg/synctx/work_queue.go b/pkg/synctx/work_queue.go index d5b3693f5..039166894 100644 --- a/pkg/synctx/work_queue.go +++ b/pkg/synctx/work_queue.go @@ -9,11 +9,11 @@ import ( ) // WorkFunc a function that receives an context and optionally returns -// an error. Returning an error will cancel all other worker functions +// an error. Returning an error will cancel all other worker functions. type WorkFunc func(ctx context.Context) error // WorkQueue is a work queue implementation that respects cancellation -// using contexts +// using contexts. type WorkQueue struct { wg WaitGroup mu sync.Mutex @@ -24,9 +24,10 @@ type WorkQueue struct { } // NewWorkQueue creates a new WorkQueue that respects -// the passed context for cancellation +// the passed context for cancellation. func NewWorkQueue(ctx context.Context) *WorkQueue { ctx, cancel := context.WithCancel(ctx) + return &WorkQueue{ ctx: ctx, done: make(chan struct{}), @@ -39,36 +40,32 @@ func NewWorkQueue(ctx context.Context) *WorkQueue { // will be immediately executed. func (queue *WorkQueue) Add(description string, fn WorkFunc) { queue.wg.Add(1) + go func() { - err := fn(queue.ctx) - // if one of the work queue items fails the whole - // queue will be canceled - if err != nil { - queue.setErr(fmt.Errorf("failed to %s: %v", description, err)) + if err := fn(queue.ctx); err != nil { + queue.setErr(fmt.Errorf("failed to %s: %w", description, err)) queue.cancel() } + queue.wg.Done() }() } // Wait waits until all worker functions are done, -// one worker is failing or the context is canceled +// one worker is failing or the context is canceled. func (queue *WorkQueue) Wait() { defer queue.cancel() select { case <-queue.wg.Finish(): case <-queue.ctx.Done(): - err := queue.ctx.Err() - // if the queue was canceled and no error was set already - // store the error - if err != nil { + if err := queue.ctx.Err(); err != nil { queue.setErr(err) } } } -// Err returns the error if one of the work queue items failed +// Err returns the error if one of the work queue items failed. func (queue *WorkQueue) Err() error { queue.mu.Lock() defer queue.mu.Unlock() @@ -76,7 +73,7 @@ func (queue *WorkQueue) Err() error { return queue.err } -// setErr sets the error on the queue if not set already +// setErr sets the error on the queue if not set already. func (queue *WorkQueue) setErr(err error) { queue.mu.Lock() defer queue.mu.Unlock() diff --git a/pkg/synctx/work_queue_test.go b/pkg/synctx/work_queue_test.go index 871b89218..cb8f87a66 100644 --- a/pkg/synctx/work_queue_test.go +++ b/pkg/synctx/work_queue_test.go @@ -13,6 +13,7 @@ func TestWorkQueueNoTask(t *testing.T) { ctx := context.Background() q := NewWorkQueue(ctx) q.Wait() + if q.Err() != nil { t.Error("expected no error") } @@ -25,10 +26,12 @@ func TestWorkQueueOneTask(t *testing.T) { if ctx1 == ctx { t.Error("should not directly pass the context") } + return nil }) q.Wait() + if q.Err() != nil { t.Error("expected no error") } @@ -42,10 +45,12 @@ func TestWorkQueueOneTaskWithErr(t *testing.T) { }) q.Wait() + if q.Err() == nil { t.Error("expected error") return } + expected := "failed to some work: Some error" if q.Err().Error() != expected { t.Errorf("expected error %q, got: %q", q.Err().Error(), expected) @@ -55,6 +60,7 @@ func TestWorkQueueOneTaskWithErr(t *testing.T) { func TestWorkQueueOneTaskWithCancel(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) cancel() + q := NewWorkQueue(ctx) q.Add("some work", func(ctx context.Context) error { time.Sleep(10 * time.Millisecond) @@ -62,10 +68,12 @@ func TestWorkQueueOneTaskWithCancel(t *testing.T) { }) q.Wait() + if q.Err() == nil { t.Error("expected error") return } + expected := "context canceled" if q.Err().Error() != expected { t.Errorf("expected error %q, got: %q", q.Err().Error(), expected) diff --git a/pkg/tracking/utm/context.go b/pkg/tracking/utm/context.go index 256e43bdd..9e5b1911e 100644 --- a/pkg/tracking/utm/context.go +++ b/pkg/tracking/utm/context.go @@ -6,7 +6,6 @@ type ctxKey struct{} var key = ctxKey{} -// https://en.wikipedia.org/wiki/UTM_parameters type UTMData struct { Source string Medium string @@ -46,6 +45,7 @@ func ContextWithUTMData(parentCtx context.Context, data UTMData) context.Context func FromContext(ctx context.Context) (UTMData, bool) { val := ctx.Value(key) data, found := val.(UTMData) + return data, found } @@ -54,5 +54,6 @@ func ContextTransfer(in, out context.Context) context.Context { if !exists { return out // do nothing } + return ContextWithUTMData(out, utmData) } diff --git a/pkg/tracking/utm/context_test.go b/pkg/tracking/utm/context_test.go index b355d476c..2acfef17e 100644 --- a/pkg/tracking/utm/context_test.go +++ b/pkg/tracking/utm/context_test.go @@ -20,6 +20,7 @@ func TestContextWithUTMData(t *testing.T) { ctxWithData := ContextWithUTMData(ctx, data) _, found := FromContext(ctx) assert.False(t, found) + dataFromCtx, found := FromContext(ctxWithData) assert.True(t, found) assert.Equal(t, data, dataFromCtx) diff --git a/pkg/tracking/utm/http.go b/pkg/tracking/utm/http.go index 6cc7e45ab..45523aece 100644 --- a/pkg/tracking/utm/http.go +++ b/pkg/tracking/utm/http.go @@ -21,6 +21,7 @@ func FromRequest(req *http.Request) (UTMData, error) { if data == emptyData { return emptyData, ErrNotFound } + return data, nil } @@ -28,6 +29,7 @@ func AttachToRequest(data UTMData, req *http.Request) *http.Request { if data == emptyData { return req } + q := req.URL.Query() q.Set("utm_source", data.Source) q.Set("utm_medium", data.Medium) @@ -35,11 +37,13 @@ func AttachToRequest(data UTMData, req *http.Request) *http.Request { q.Set("utm_term", data.Term) q.Set("utm_content", data.Content) q.Set("utm_partner_client", data.Client) + req.URL.RawQuery = q.Encode() + return req } -// Middleware attempts to attach utm data found in the request to the request context +// Middleware attempts to attach utm data found in the request to the request context. func Middleware() func(http.Handler) http.Handler { return func(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { @@ -48,10 +52,12 @@ func Middleware() func(http.Handler) http.Handler { next.ServeHTTP(w, r) return } + clientID, found := oauth2.ClientID(r.Context()) if found && data.Client == "" { data.Client = clientID } + ctx := ContextWithUTMData(r.Context(), data) r = r.WithContext(ctx) next.ServeHTTP(w, r) @@ -73,8 +79,10 @@ func (r *RoundTripper) RoundTrip(req *http.Request) (*http.Response, error) { if !found { // no utm data found, skip directly to next roundtripper return r.transport.RoundTrip(req) } + newReq := cloneRequest(req) newReq = AttachToRequest(data, newReq) + return r.transport.RoundTrip(newReq) } @@ -98,5 +106,6 @@ func cloneRequest(r *http.Request) *http.Request { for k, s := range r.Header { r2.Header[k] = append([]string(nil), s...) } + return r2 } diff --git a/pkg/tracking/utm/http_test.go b/pkg/tracking/utm/http_test.go index 6fcffbed5..5d0927cab 100644 --- a/pkg/tracking/utm/http_test.go +++ b/pkg/tracking/utm/http_test.go @@ -51,7 +51,7 @@ func TestRoundTripper_RoundTrip(t *testing.T) { }) req := httptest.NewRequest(http.MethodGet, "http://example.org/?utm_source=internet", nil) req = req.WithContext(ctx) - resp, err := tripper.RoundTrip(req) + resp, err := tripper.RoundTrip(req) //nolint:bodyclose require.NoError(t, err) require.NotNil(t, resp) } @@ -71,5 +71,6 @@ func (m *mockTripper) RoundTrip(req *http.Request) (*http.Response, error) { assert.Equal(m.t, v, h, fmt.Sprintf("expected query paramater %q to match value", k)) } } + return m.resp, nil } diff --git a/test/livetest/init.go b/test/livetest/init.go index b2f2532ed..d6d15ca59 100644 --- a/test/livetest/init.go +++ b/test/livetest/init.go @@ -40,8 +40,7 @@ func init() { prometheus.MustRegister(paceLivetestDurationSeconds) // parse log config - err := env.Parse(&cfg) - if err != nil { + if err := env.Parse(&cfg); err != nil { log.Fatalf("Failed to parse livetest environment: %v", err) } } diff --git a/test/livetest/livetest.go b/test/livetest/livetest.go index 97d94ba61..dcdd63714 100644 --- a/test/livetest/livetest.go +++ b/test/livetest/livetest.go @@ -4,14 +4,16 @@ package livetest import ( "context" + "errors" "fmt" "time" "github.com/getsentry/sentry-go" + "github.com/pace/bricks/maintenance/log" ) -// TestFunc represents a single test (possibly with sub tests) +// TestFunc represents a single test (possibly with sub tests). type TestFunc func(t *T) // Test executes the passed tests in the given order (array order). @@ -45,8 +47,7 @@ func testRun(ctx context.Context, tests []TestFunc) { Int("test", i+1).Logger() ctx = logger.WithContext(ctx) - err = executeTest(ctx, test, fmt.Sprintf("test-%d", i+1)) - if err != nil { + if err := executeTest(ctx, test, fmt.Sprintf("test-%d", i+1)); err != nil { break } } @@ -67,10 +68,20 @@ func executeTest(ctx context.Context, t TestFunc, name string) error { proxy := NewTestProxy(ctx, name) startTime := time.Now() + func() { defer func() { err := recover() - if err != nil && (err != ErrSkipNow || err != ErrFailNow) { + if err == nil { + return + } + + recoveredErr, ok := err.(error) + if !ok { + return + } + + if !errors.Is(recoveredErr, ErrSkipNow) || !errors.Is(recoveredErr, ErrFailNow) { logger.Error().Msgf("PANIC: %+v", err) log.Stack(ctx) proxy.Fail() @@ -79,7 +90,9 @@ func executeTest(ctx context.Context, t TestFunc, name string) error { t(proxy) }() + duration := float64(time.Since(startTime)) / float64(time.Second) + proxy.okIfNoSkipFail() paceLivetestDurationSeconds.WithLabelValues(cfg.ServiceName).Observe(duration) diff --git a/test/livetest/livetest_example_test.go b/test/livetest/livetest_example_test.go index 908aba1fe..f8fc36014 100644 --- a/test/livetest/livetest_example_test.go +++ b/test/livetest/livetest_example_test.go @@ -4,6 +4,7 @@ package livetest_test import ( "context" + "errors" "log" "time" @@ -49,7 +50,7 @@ func ExampleTest() { t.Errorf("formatted") }, }) - if err != context.DeadlineExceeded { + if !errors.Is(err, context.DeadlineExceeded) { log.Fatal(err) } // Output: diff --git a/test/livetest/livetest_test.go b/test/livetest/livetest_test.go index bc32a2a76..d065ce915 100644 --- a/test/livetest/livetest_test.go +++ b/test/livetest/livetest_test.go @@ -4,11 +4,14 @@ package livetest import ( "context" + "net/http" "net/http/httptest" "strings" "testing" "time" + "github.com/stretchr/testify/require" + "github.com/pace/bricks/maintenance/metric" ) @@ -56,14 +59,13 @@ func TestIntegrationExample(t *testing.T) { t.Errorf("formatted") }, }) - if err != context.DeadlineExceeded { - t.Error(err) - return - } - req := httptest.NewRequest("GET", "/metrics", nil) + require.ErrorIs(t, err, context.DeadlineExceeded) + + req := httptest.NewRequest(http.MethodGet, "/metrics", nil) resp := httptest.NewRecorder() metric.Handler().ServeHTTP(resp, req) + body := resp.Body.String() sn := cfg.ServiceName diff --git a/test/livetest/test_proxy.go b/test/livetest/test_proxy.go index 90f35b3d0..6a0090aed 100644 --- a/test/livetest/test_proxy.go +++ b/test/livetest/test_proxy.go @@ -10,27 +10,27 @@ import ( "github.com/pace/bricks/maintenance/log" ) -// ErrSkipNow is used as a panic if ErrSkipNow is called on the test +// ErrSkipNow is used as a panic if ErrSkipNow is called on the test. var ErrSkipNow = errors.New("skipped test") -// ErrFailNow is used as a panic if ErrFailNow is called on the test +// ErrFailNow is used as a panic if ErrFailNow is called on the test. var ErrFailNow = errors.New("failed test") -// TestState represents the state of a test +// TestState represents the state of a test. type TestState string var ( - // StateRunning first state + // StateRunning first state. StateRunning TestState = "running" - // StateOK test was executed without failure + // StateOK test was executed without failure. StateOK TestState = "ok" - // StateFailed test was executed with failure + // StateFailed test was executed with failure. StateFailed TestState = "failed" - // StateSkipped test was skipped + // StateSkipped test was skipped. StateSkipped TestState = "skipped" ) -// T implements a similar interface than testing.T +// T implements a similar interface than testing.T. type T struct { name string ctx context.Context @@ -45,97 +45,102 @@ func NewTestProxy(ctx context.Context, name string) *T { // Context returns the livetest context. Useful // for passing timeout and/or logging constraints from -// the test executor to the individual case +// the test executor to the individual case. func (t *T) Context() context.Context { return t.ctx } -// Error logs an error message with the test +// Error logs an error message with the test. func (t *T) Error(args ...interface{}) { log.Ctx(t.ctx).Error().Msg(fmt.Sprint(args...)) t.Fail() } -// Errorf logs an error message with the test +// Errorf logs an error message with the test. func (t *T) Errorf(format string, args ...interface{}) { log.Ctx(t.ctx).Error().Msgf(format, args...) t.Fail() } -// Fail marks the test as failed +// Fail marks the test as failed. func (t *T) Fail() { log.Ctx(t.ctx).Info().Msg("Fail...") + if t.state == StateRunning { t.state = StateFailed } } -// FailNow marks the test as failed and skips further execution +// FailNow marks the test as failed and skips further execution. func (t *T) FailNow() { t.Fail() panic(ErrFailNow) } -// Failed returns true if the test was marked as failed +// Failed returns true if the test was marked as failed. func (t *T) Failed() bool { return t.state == StateFailed } -// Fatal logs the passed message in the context of the test and fails the test +// Fatal logs the passed message in the context of the test and fails the test. func (t *T) Fatal(args ...interface{}) { log.Ctx(t.ctx).Error().Msg(fmt.Sprint(args...)) t.FailNow() } -// Fatalf logs the passed message in the context of the test and fails the test +// Fatalf logs the passed message in the context of the test and fails the test. func (t *T) Fatalf(format string, args ...interface{}) { log.Ctx(t.ctx).Error().Msgf(format, args...) t.FailNow() } -// Log logs the passed message in the context of the test +// Log logs the passed message in the context of the test. func (t *T) Log(args ...interface{}) { log.Ctx(t.ctx).Info().Msg(fmt.Sprint(args...)) } -// Logf logs the passed message in the context of the test +// Logf logs the passed message in the context of the test. func (t *T) Logf(format string, args ...interface{}) { log.Ctx(t.ctx).Info().Msgf(format, args...) } -// Name returns the name of the test +// Name returns the name of the test. func (t *T) Name() string { return t.name } -// Skip logs reason and marks the test as skipped +// Skip logs reason and marks the test as skipped. func (t *T) Skip(args ...interface{}) { log.Ctx(t.ctx).Info().Msg("Skip...") log.Ctx(t.ctx).Info().Msg(fmt.Sprint(args...)) + if t.state == StateRunning { t.state = StateSkipped } } -// SkipNow skips the test immediately +// SkipNow skips the test immediately. func (t *T) SkipNow() { log.Ctx(t.ctx).Info().Msg("Skip...") + if t.state == StateRunning { t.state = StateSkipped } + panic(ErrSkipNow) } -// Skipf marks the test as skippend and log a reason +// Skipf marks the test as skippend and log a reason. func (t *T) Skipf(format string, args ...interface{}) { log.Ctx(t.ctx).Info().Msg("Skip...") log.Ctx(t.ctx).Info().Msgf(format, args...) + if t.state == StateRunning { t.state = StateSkipped } } -// Skipped returns true if the test was skipped +// Skipped returns true if the test was skipped. func (t *T) Skipped() bool { return t.state == StateSkipped } diff --git a/tools/jsonapigen/main.go b/tools/jsonapigen/main.go index a2851902c..cfe08e488 100644 --- a/tools/jsonapigen/main.go +++ b/tools/jsonapigen/main.go @@ -26,7 +26,7 @@ func main() { log.Fatal(err) } - f, err := os.Create(path) + f, err := os.Create(path) //nolint:gosec if err != nil { log.Fatal(err) } diff --git a/tools/testserver/main.go b/tools/testserver/main.go index 6e8fbc3b8..66e55cf2d 100755 --- a/tools/testserver/main.go +++ b/tools/testserver/main.go @@ -10,21 +10,20 @@ import ( "time" "github.com/getsentry/sentry-go" - "github.com/pace/bricks/grpc" - "github.com/pace/bricks/http/security" - "github.com/pace/bricks/http/transport" - "github.com/pace/bricks/locale" - - "github.com/pace/bricks/maintenance/failover" - "github.com/pace/bricks/maintenance/health/servicehealthcheck" "github.com/pace/bricks/backend/couchdb" "github.com/pace/bricks/backend/objstore" "github.com/pace/bricks/backend/postgres" "github.com/pace/bricks/backend/redis" + "github.com/pace/bricks/grpc" pacehttp "github.com/pace/bricks/http" "github.com/pace/bricks/http/oauth2" + "github.com/pace/bricks/http/security" + "github.com/pace/bricks/http/transport" + "github.com/pace/bricks/locale" "github.com/pace/bricks/maintenance/errors" + "github.com/pace/bricks/maintenance/failover" + "github.com/pace/bricks/maintenance/health/servicehealthcheck" "github.com/pace/bricks/maintenance/log" _ "github.com/pace/bricks/maintenance/tracing" "github.com/pace/bricks/test/livetest" @@ -32,7 +31,7 @@ import ( simple "github.com/pace/bricks/tools/testserver/simple" ) -// pace lat/lon +// pace lat/lon. var ( lat = 49.012553 lon = 8.427087 @@ -51,8 +50,9 @@ func (*OauthBackend) IntrospectToken(ctx context.Context, token string) (*oauth2 type TestService struct{} -func (*TestService) GetTest(ctx context.Context, w simple.GetTestResponseWriter, r *simple.GetTestRequest) error { +func (*TestService) GetTest(ctx context.Context, _ simple.GetTestResponseWriter, _ *simple.GetTestRequest) error { log.Debug("Request in flight, this will wait 5 min....") + for t := 0; t < 360; t++ { select { case <-ctx.Done(): @@ -61,16 +61,19 @@ func (*TestService) GetTest(ctx context.Context, w simple.GetTestResponseWriter, time.Sleep(time.Second) } } + return nil } func main() { db := postgres.DefaultConnectionPool() rdb := redis.Client() + cdb, err := couchdb.DefaultDatabase() if err != nil { log.Fatal(err) } + _, err = objstore.Client() if err != nil { log.Fatal(err) @@ -80,15 +83,23 @@ func main() { if err != nil { log.Fatal(err) } - go ap.Run(log.WithContext(context.Background())) // nolint: errcheck + + go func() { + if err := ap.Run(log.WithContext(context.Background())); err != nil { + log.Println(err) + } + }() h := pacehttp.Router() + servicehealthcheck.RegisterHealthCheckFunc("fail-50", func(ctx context.Context) (r servicehealthcheck.HealthCheckResult) { if time.Now().Unix()%2 == 0 { panic("boom") } + r.Msg = "Foo" r.State = servicehealthcheck.Ok + return }) @@ -104,14 +115,17 @@ func main() { // do dummy database query cdb := db.WithContext(ctx) + var result struct { Calc int //nolint } + res, err := cdb.QueryOne(&result, `SELECT ? + ? AS Calc`, 10, 10) if err != nil { log.Ctx(ctx).Debug().Err(err).Msg("Calc failed") return } + log.Ctx(ctx).Debug().Int("rows_affected", res.RowsAffected()).Msg("Calc done") // do dummy redis query @@ -124,7 +138,10 @@ func main() { // do dummy call to external service log.Ctx(ctx).Debug().Msg("Test before JSON") w.Header().Set("Content-Type", "application/json") - fmt.Fprintf(w, `{"street":"Haid-und-Neu-Straße 18, 76131 Karlsruhe", "sunset": "%s"}`, fetchSunsetandSunrise(ctx)) + + if _, err := fmt.Fprintf(w, `{"street":"Haid-und-Neu-Straße 18, 76131 Karlsruhe", "sunset": "%s"}`, fetchSunsetandSunrise(ctx)); err != nil { + log.Ctx(ctx).Warn().Err(err).Msg("Failed writing message") + } }) h.HandleFunc("/grpc", func(rw http.ResponseWriter, r *http.Request) { @@ -134,11 +151,17 @@ func main() { if err != nil { log.Fatalf("did not connect: %s", err) } - defer conn.Close() + + defer func() { + if err := conn.Close(); err != nil { + log.Printf("Failed closing connection: %v", err) + } + }() ctx = security.ContextWithToken(ctx, security.TokenString("test")) c := math.NewMathServiceClient(conn) + o, err := c.Add(ctx, &math.Input{ A: 1, B: 23, @@ -147,20 +170,21 @@ func main() { log.Ctx(ctx).Debug().Err(err).Msg("failed to add") return } - log.Ctx(ctx).Info().Msgf("C: %d", o.C) + + log.Ctx(ctx).Info().Msgf("C: %d", o.GetC()) ctx = locale.WithLocale(ctx, locale.NewLocale("fr-CH", "Europe/Paris")) _, err = c.Add(ctx, &math.Input{}) if err != nil { - log.Ctx(ctx).Debug().Err(err).Msg("failed to substract") + log.Ctx(ctx).Debug().Err(err).Msg("failed to add") return } if r.URL.Query().Get("error") != "" { - _, err = c.Substract(ctx, &math.Input{}) + _, err = c.Subtract(ctx, &math.Input{}) if err != nil { - log.Ctx(ctx).Debug().Err(err).Msg("failed to substract") + log.Ctx(ctx).Debug().Err(err).Msg("failed to subtract") return } } @@ -171,20 +195,28 @@ func main() { if row.Err != nil { log.Println(err) w.WriteHeader(http.StatusInternalServerError) + return } + var doc interface{} - row.ScanDoc(&doc) // nolint: errcheck + + if err := row.ScanDoc(&doc); err != nil { + log.Printf("Failed scanning document: %v", err) + } w.Header().Set("Content-Type", "application/json") - json.NewEncoder(w).Encode(doc) // nolint: errcheck + + if err := json.NewEncoder(w).Encode(doc); err != nil { + log.Printf("Failed encoding document: %v", err) + } }) h.HandleFunc("/panic", func(w http.ResponseWriter, r *http.Request) { go func() { defer errors.HandleWithCtx(r.Context(), "Some worker") - panic(fmt.Errorf("Something went wrong %d - times", 100)) + panic(fmt.Errorf("something went wrong %d - times", 100)) }() panic("Test for sentry") @@ -198,7 +230,7 @@ func main() { // Test OAuth // // This middleware is configured against an Oauth application dummy - m := oauth2.NewMiddleware(new(OauthBackend)) // nolint: staticcheck + m := oauth2.NewMiddleware(new(OauthBackend)) //nolint:staticcheck sr := h.PathPrefix("/test").Subrouter() sr.Use(m.Handler) @@ -207,29 +239,39 @@ func main() { // // curl -H "Authorization: Bearer 83142f1b767e910e78ba2d554b6708c371f053d13d6075bcc39766853a932253" localhost:3000/test/auth sr.HandleFunc("/oauth", func(w http.ResponseWriter, r *http.Request) { - fmt.Fprintf(w, "Oauth test successful.\n") + if _, err := fmt.Fprintf(w, "Oauth test successful.\n"); err != nil { + log.Logger().Warn().Err(err).Msg("Failed testing OAuth") + } }) s := pacehttp.Server(h) log.Logger().Info().Str("addr", s.Addr).Msg("Starting testserver ...") - // nolint:errcheck - go livetest.Test(context.Background(), []livetest.TestFunc{ - func(t *livetest.T) { - t.Log("Test /test query") - - resp, err := http.Get("http://localhost:3000/test") - if err != nil { - t.Error(err) - t.Fail() - return - } - if resp.StatusCode != 200 { - t.Logf("Received status code: %d", resp.StatusCode) - t.Fail() - } - }, - }) + go func() { + if err := livetest.Test(context.Background(), []livetest.TestFunc{ + func(t *livetest.T) { + t.Log("Test /test query") + + resp, err := http.Get("http://localhost:3000/test") + if err != nil { + t.Error(err) + t.Fail() + return + } + + defer func() { + _ = resp.Body.Close() + }() + + if resp.StatusCode != http.StatusOK { + t.Logf("Received status code: %d", resp.StatusCode) + t.Fail() + } + }, + }); err != nil { + log.Logger().Warn().Err(err).Msg("Failure during livetest") + } + }() log.Fatal(s.ListenAndServe()) } @@ -244,7 +286,8 @@ func fetchSunsetandSunrise(ctx context.Context) string { span.SetData("lon", lon) url := fmt.Sprintf("https://api.sunrise-sunset.org/json?lat=%f&lng=%f&date=today", lat, lon) - req, err := http.NewRequestWithContext(ctx, "GET", url, nil) + + req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil) if err != nil { log.Fatal(err) } @@ -252,11 +295,17 @@ func fetchSunsetandSunrise(ctx context.Context) string { c := &http.Client{ Transport: transport.NewDefaultTransportChain(), } + resp, err := c.Do(req) if err != nil { log.Fatal(err) } - defer resp.Body.Close() + + defer func() { + if err := resp.Body.Close(); err != nil { + log.Println(err) + } + }() var r struct { Results struct { @@ -264,8 +313,7 @@ func fetchSunsetandSunrise(ctx context.Context) string { } `json:"results"` } - err = json.NewDecoder(resp.Body).Decode(&r) - if err != nil { + if err := json.NewDecoder(resp.Body).Decode(&r); err != nil { log.Fatal(err) } @@ -273,8 +321,10 @@ func fetchSunsetandSunrise(ctx context.Context) string { if err != nil { log.Fatal(err) } + sunset = sunset.Local() log.Ctx(ctx).Debug().Time("sunset", sunset).Str("str", r.Results.Sunset).Msg("Parsed sunset time") + return sunset.String() } diff --git a/tools/testserver/math/math.proto b/tools/testserver/math/math.proto index 4970e0fcc..549b1612d 100644 --- a/tools/testserver/math/math.proto +++ b/tools/testserver/math/math.proto @@ -15,5 +15,5 @@ message Output { service MathService { rpc Add(Input) returns (Output); - rpc Substract(Input) returns (Output); + rpc Subtract(Input) returns (Output); } \ No newline at end of file diff --git a/tools/testserver/simple/open-api.go b/tools/testserver/simple/open-api.go index e9db7d885..10626fcb2 100644 --- a/tools/testserver/simple/open-api.go +++ b/tools/testserver/simple/open-api.go @@ -59,7 +59,7 @@ func GetTestHandler(service GetTestHandlerService) http.Handler { /* GetTestResponseWriter is a standard http.ResponseWriter extended with methods -to generate the respective responses easily +to generate the respective responses easily. */ type GetTestResponseWriter interface { http.ResponseWriter @@ -69,7 +69,7 @@ type getTestResponseWriter struct { http.ResponseWriter } -// OK responds with empty response (HTTP code 200) +// OK responds with empty response (HTTP code 200). func (w *getTestResponseWriter) OK() { w.Header().Set("Content-Type", "application/vnd.api+json") w.WriteHeader(200) @@ -77,13 +77,13 @@ func (w *getTestResponseWriter) OK() { /* GetTestRequest is a standard http.Request extended with the -un-marshaled content object +un-marshaled content object. */ type GetTestRequest struct { Request *http.Request `valid:"-"` } -// Service interface for GetTestHandler handler +// Service interface for GetTestHandler handler. type GetTestHandlerService interface { // GetTest Test GetTest(context.Context, GetTestResponseWriter, *GetTestRequest) error diff --git a/tools/testserver/simplemath/main.go b/tools/testserver/simplemath/main.go index dc292d6cd..fd6524031 100644 --- a/tools/testserver/simplemath/main.go +++ b/tools/testserver/simplemath/main.go @@ -5,6 +5,7 @@ import ( "fmt" "github.com/getsentry/sentry-go" + "github.com/pace/bricks/grpc" "github.com/pace/bricks/http/security" "github.com/pace/bricks/locale" @@ -25,6 +26,7 @@ func (*GrpcAuthBackend) AuthorizeUnary(ctx context.Context) (context.Context, er } else { return nil, fmt.Errorf("unauthenticated") } + return ctx, nil } @@ -36,18 +38,21 @@ func (*SimpleMathServer) Add(ctx context.Context, i *math.Input) (*math.Output, if loc, ok := locale.FromCtx(ctx); ok { log.Ctx(ctx).Debug().Msgf("Locale: %q", loc.Serialize()) } + span := sentry.SpanFromContext(ctx) if span != nil { log.Ctx(ctx).Debug().Msgf("Span: %q", span.Name) } var o math.Output - o.C = i.A + i.B - log.Ctx(ctx).Debug().Msgf("A: %d + B: %d = C: %d", i.A, i.B, o.C) + + o.C = i.GetA() + i.GetB() + log.Ctx(ctx).Debug().Msgf("A: %d + B: %d = C: %d", i.GetA(), i.GetB(), o.GetC()) + return &o, nil } -func (*SimpleMathServer) Substract(ctx context.Context, i *math.Input) (*math.Output, error) { +func (*SimpleMathServer) Subtract(ctx context.Context, i *math.Input) (*math.Output, error) { panic("not implemented") } @@ -56,8 +61,7 @@ func main() { gs := grpc.Server(&GrpcAuthBackend{}) math.RegisterMathServiceServer(gs, ms) - err := grpc.ListenAndServe(gs) - if err != nil { + if err := grpc.ListenAndServe(gs); err != nil { log.Fatal(err) } }