From e0481fa166afaf2f1290a8e669202680b494103b Mon Sep 17 00:00:00 2001 From: Gerrit Date: Fri, 23 Aug 2024 13:21:43 +0200 Subject: [PATCH] Small improvements. --- auditing/auditing.go | 33 +++--- auditing/meilisearch.go | 2 +- auditing/meilisearch_integration_test.go | 15 +-- auditing/timescaledb.go | 122 +++++++++++++++++------ auditing/timescaledb_integration_test.go | 15 +-- 5 files changed, 125 insertions(+), 62 deletions(-) diff --git a/auditing/auditing.go b/auditing/auditing.go index 424470a..d4e1ad7 100644 --- a/auditing/auditing.go +++ b/auditing/auditing.go @@ -1,6 +1,7 @@ package auditing import ( + "context" "log/slog" "os" "path/filepath" @@ -49,32 +50,32 @@ const ( const EntryFilterDefaultLimit int64 = 100 type Entry struct { - Id string `db:"-"` // filled by the auditing driver + Id string // filled by the auditing driver - Component string `db:"component"` - RequestId string `db:"rqid" json:"rqid"` - Type EntryType `db:"type"` - Timestamp time.Time `db:"timestamp"` + Component string + RequestId string `json:"rqid"` + Type EntryType + Timestamp time.Time - User string `db:"userid"` - Tenant string `db:"tenant"` + User string + Tenant string // For `EntryDetailHTTP` the HTTP method get, post, put, delete, ... // For `EntryDetailGRPC` unary, stream - Detail EntryDetail `db:"detail"` + Detail EntryDetail // e.g. Request, Response, Error, Opened, Close - Phase EntryPhase `db:"phase"` + Phase EntryPhase // For `EntryDetailHTTP` /api/v1/... // For `EntryDetailGRPC` /api.v1/... (the method name) - Path string `db:"path"` - ForwardedFor string `db:"forwardedfor"` - RemoteAddr string `db:"remoteaddr"` + Path string + ForwardedFor string + RemoteAddr string - Body any `db:"body"` // JSON, string or numbers - StatusCode int `db:"statuscode"` // for `EntryDetailHTTP` the HTTP status code, for EntryDetailGRPC` the grpc status code + Body any // JSON, string or numbers + StatusCode int // for `EntryDetailHTTP` the HTTP status code, for EntryDetailGRPC` the grpc status code // Internal errors - Error string `db:"error"` + Error string } func (e *Entry) prepareForNextPhase() { @@ -133,7 +134,7 @@ type Auditing interface { // Searches for entries matching the given filter. // By default only recent entries will be returned. // The returned entries will be sorted by timestamp in descending order. - Search(EntryFilter) ([]Entry, error) + Search(context.Context, EntryFilter) ([]Entry, error) } func defaultComponent() (string, error) { diff --git a/auditing/meilisearch.go b/auditing/meilisearch.go index df29b78..8cbcc5f 100644 --- a/auditing/meilisearch.go +++ b/auditing/meilisearch.go @@ -129,7 +129,7 @@ func (a *meiliAuditing) Index(entry Entry) error { return nil } -func (a *meiliAuditing) Search(filter EntryFilter) ([]Entry, error) { +func (a *meiliAuditing) Search(_ context.Context, filter EntryFilter) ([]Entry, error) { predicates := make([]string, 0) if filter.Component != "" { predicates = append(predicates, fmt.Sprintf("component = %q", filter.Component)) diff --git a/auditing/meilisearch_integration_test.go b/auditing/meilisearch_integration_test.go index 0d2f9d7..8d6d709 100644 --- a/auditing/meilisearch_integration_test.go +++ b/auditing/meilisearch_integration_test.go @@ -70,6 +70,7 @@ func StartMeilisearch(t testing.TB) (container testcontainers.Container, c *conn } func TestAuditing_Meilisearch(t *testing.T) { + ctx := context.Background() container, c, err := StartMeilisearch(t) require.NoError(t, err) defer func() { @@ -143,7 +144,7 @@ func TestAuditing_Meilisearch(t *testing.T) { { name: "no entries, no search results", t: func(t *testing.T, a Auditing) { - entries, err := a.Search(EntryFilter{}) + entries, err := a.Search(ctx, EntryFilter{}) require.NoError(t, err) assert.Empty(t, entries) }, @@ -158,7 +159,7 @@ func TestAuditing_Meilisearch(t *testing.T) { err = a.Flush() require.NoError(t, err) - entries, err := a.Search(EntryFilter{ + entries, err := a.Search(ctx, EntryFilter{ Body: "test", }) require.NoError(t, err) @@ -177,7 +178,7 @@ func TestAuditing_Meilisearch(t *testing.T) { err = a.Flush() require.NoError(t, err) - entries, err := a.Search(EntryFilter{}) + entries, err := a.Search(ctx, EntryFilter{}) require.NoError(t, err) assert.Len(t, entries, len(es)) @@ -187,7 +188,7 @@ func TestAuditing_Meilisearch(t *testing.T) { t.Errorf("diff (+got -want):\n %s", diff) } - entries, err = a.Search(EntryFilter{ + entries, err = a.Search(ctx, EntryFilter{ Body: "This", }) require.NoError(t, err) @@ -206,7 +207,7 @@ func TestAuditing_Meilisearch(t *testing.T) { err = a.Flush() require.NoError(t, err) - entries, err := a.Search(EntryFilter{ + entries, err := a.Search(ctx, EntryFilter{ RequestId: es[0].RequestId, }) require.NoError(t, err) @@ -234,7 +235,7 @@ func TestAuditing_Meilisearch(t *testing.T) { err = a.Flush() require.NoError(t, err) - entries, err := a.Search(EntryFilter{ + entries, err := a.Search(ctx, EntryFilter{ Phase: EntryPhaseResponse, }) require.NoError(t, err) @@ -259,7 +260,7 @@ func TestAuditing_Meilisearch(t *testing.T) { err = a.Flush() require.NoError(t, err) - entries, err := a.Search(EntryFilter{ + entries, err := a.Search(ctx, EntryFilter{ // we want to run a phrase search as otherwise we return the other entries as well // https://www.meilisearch.com/docs/reference/api/search#phrase-search-2 Body: fmt.Sprintf("%q", es[0].Body.(string)), diff --git a/auditing/timescaledb.go b/auditing/timescaledb.go index 6b19833..98840ea 100644 --- a/auditing/timescaledb.go +++ b/auditing/timescaledb.go @@ -3,6 +3,7 @@ package auditing import ( "context" "database/sql" + "encoding/json" "fmt" "log/slog" "reflect" @@ -17,22 +18,50 @@ import ( _ "github.com/lib/pq" ) -type TimescaleDbConfig struct { - Host string - Port string - DB string - User string - Password string -} - -type timescaleAuditing struct { - component string - db *sqlx.DB - log *slog.Logger +type ( + TimescaleDbConfig struct { + Host string + Port string + DB string + User string + Password string + } + + timescaleAuditing struct { + component string + db *sqlx.DB + log *slog.Logger + + cols []string + vals []any + } + + // to keep the public interface free from field tags like "db" and "json" (as these might differ for different dbs) + // we introduce an internal type. unfortunately, this requires a conversion, which takes effort to maintain + timescaleEntry struct { + Component string `db:"component"` + RequestId string `db:"rqid" json:"rqid"` + Type EntryType `db:"type"` + Timestamp time.Time `db:"timestamp"` + User string `db:"userid"` + Tenant string `db:"tenant"` + Detail EntryDetail `db:"detail"` + Phase EntryPhase `db:"phase"` + Path string `db:"path"` + ForwardedFor string `db:"forwardedfor"` + RemoteAddr string `db:"remoteaddr"` + Body any `db:"body"` + StatusCode int `db:"statuscode"` + Error string `db:"error"` + } + + sqlCompOp string +) - cols []string - vals []any -} +const ( + equals sqlCompOp = "equals" + like sqlCompOp = "like" +) func NewTimescaleDB(c Config, tc TimescaleDbConfig) (Auditing, error) { if c.Component == "" { @@ -187,12 +216,15 @@ func (a *timescaleAuditing) Index(entry Entry) error { return err } + internalEntry, err := a.toInternal(entry) + if err != nil { + return fmt.Errorf("unable to convert audit trace to database entry: %w", err) + } + ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) defer cancel() - entry.Id = entry.RequestId - - _, err = a.db.NamedExecContext(ctx, q, entry) + _, err = a.db.NamedExecContext(ctx, q, internalEntry) if err != nil { return fmt.Errorf("unable to index audit trace: %w", err) } @@ -200,18 +232,11 @@ func (a *timescaleAuditing) Index(entry Entry) error { return nil } -type compOp string - -const ( - equals compOp = "equals" - like compOp = "like" -) - -func (a *timescaleAuditing) Search(filter EntryFilter) ([]Entry, error) { +func (a *timescaleAuditing) Search(ctx context.Context, filter EntryFilter) ([]Entry, error) { var ( where []string values = map[string]interface{}{} - addFilter = func(field string, value any, op compOp) error { + addFilter = func(field string, value any, op sqlCompOp) error { if reflect.ValueOf(value).IsZero() { return nil } @@ -303,7 +328,7 @@ func (a *timescaleAuditing) Search(filter EntryFilter) ([]Entry, error) { return nil, err } - rows, err := a.db.NamedQueryContext(context.TODO(), q, values) // TODO: search needs a ctx! + rows, err := a.db.NamedQueryContext(ctx, q, values) if err != nil { return nil, err } @@ -312,17 +337,52 @@ func (a *timescaleAuditing) Search(filter EntryFilter) ([]Entry, error) { var entries []Entry for rows.Next() { - var e Entry + var e timescaleEntry err = rows.StructScan(&e) if err != nil { return nil, err } - e.Id = e.RequestId + entry, err := a.toExternal(e) + if err != nil { + return nil, fmt.Errorf("unable to convert entry: %w", err) + } - entries = append(entries, e) + entries = append(entries, entry) } return entries, nil } + +func (_ *timescaleAuditing) toInternal(e Entry) (*timescaleEntry, error) { + intermediate, err := json.Marshal(e) // nolint + if err != nil { + return nil, err + } + var internalEntry timescaleEntry + err = json.Unmarshal(intermediate, &internalEntry) // nolint + if err != nil { + return nil, err + } + + internalEntry.RequestId = e.RequestId + + return &internalEntry, nil +} + +func (_ *timescaleAuditing) toExternal(e timescaleEntry) (Entry, error) { + intermediate, err := json.Marshal(e) // nolint + if err != nil { + return Entry{}, err + } + var externalEntry Entry + err = json.Unmarshal(intermediate, &externalEntry) // nolint + if err != nil { + return Entry{}, err + } + + externalEntry.Id = e.RequestId + + return externalEntry, nil +} diff --git a/auditing/timescaledb_integration_test.go b/auditing/timescaledb_integration_test.go index 12e6635..f2abb36 100644 --- a/auditing/timescaledb_integration_test.go +++ b/auditing/timescaledb_integration_test.go @@ -20,6 +20,7 @@ import ( ) func TestAuditing_TimescaleDB(t *testing.T) { + ctx := context.Background() container, auditing := StartTimescaleDB(t, Config{ Log: slog.Default(), }) @@ -94,7 +95,7 @@ func TestAuditing_TimescaleDB(t *testing.T) { { name: "no entries, no search results", t: func(t *testing.T, a Auditing) { - entries, err := a.Search(EntryFilter{}) + entries, err := a.Search(ctx, EntryFilter{}) require.NoError(t, err) assert.Empty(t, entries) }, @@ -109,7 +110,7 @@ func TestAuditing_TimescaleDB(t *testing.T) { err = a.Flush() require.NoError(t, err) - entries, err := a.Search(EntryFilter{ + entries, err := a.Search(ctx, EntryFilter{ Body: "test", }) require.NoError(t, err) @@ -128,7 +129,7 @@ func TestAuditing_TimescaleDB(t *testing.T) { err := a.Flush() require.NoError(t, err) - entries, err := a.Search(EntryFilter{}) + entries, err := a.Search(ctx, EntryFilter{}) require.NoError(t, err) assert.Len(t, entries, len(es)) @@ -138,7 +139,7 @@ func TestAuditing_TimescaleDB(t *testing.T) { t.Errorf("diff (+got -want):\n %s", diff) } - entries, err = a.Search(EntryFilter{ + entries, err = a.Search(ctx, EntryFilter{ Body: "This", }) require.NoError(t, err) @@ -157,7 +158,7 @@ func TestAuditing_TimescaleDB(t *testing.T) { err := a.Flush() require.NoError(t, err) - entries, err := a.Search(EntryFilter{ + entries, err := a.Search(ctx, EntryFilter{ RequestId: es[0].RequestId, }) require.NoError(t, err) @@ -185,7 +186,7 @@ func TestAuditing_TimescaleDB(t *testing.T) { err := a.Flush() require.NoError(t, err) - entries, err := a.Search(EntryFilter{ + entries, err := a.Search(ctx, EntryFilter{ Phase: EntryPhaseResponse, }) require.NoError(t, err) @@ -210,7 +211,7 @@ func TestAuditing_TimescaleDB(t *testing.T) { // err := a.Flush() // require.NoError(t, err) - // entries, err := a.Search(EntryFilter{ + // entries, err := a.Search(ctx, EntryFilter{ // Body: fmt.Sprintf("%q", es[0].Body.(string)), // }) // require.NoError(t, err)