Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement audit backend for TimescaleDB. #151

Draft
wants to merge 13 commits into
base: master
Choose a base branch
from
46 changes: 40 additions & 6 deletions auditing/auditing-interceptor.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (
"io"
"log/slog"
"net/http"
"time"

"connectrpc.com/connect"
"github.com/emicklei/go-restful/v3"
Expand Down Expand Up @@ -39,12 +40,17 @@ func UnaryServerInterceptor(a Auditing, logger *slog.Logger, shouldAudit func(fu
requestID = str
}
if requestID == "" {
requestID = uuid.NewString()
uuid, err := uuid.NewV7()
if err != nil {
return nil, err
}
requestID = uuid.String()
}

childCtx := context.WithValue(ctx, rest.RequestIDKey, requestID)

auditReqContext := Entry{
Timestamp: time.Now(),
RequestId: requestID,
Type: EntryTypeGRPC,
Detail: EntryDetailGRPCUnary,
Expand Down Expand Up @@ -97,7 +103,11 @@ func StreamServerInterceptor(a Auditing, logger *slog.Logger, shouldAudit func(f
requestID = str
}
if requestID == "" {
requestID = uuid.NewString()
uuid, err := uuid.NewV7()
if err != nil {
return err
}
requestID = uuid.String()
}
childCtx := context.WithValue(ss.Context(), rest.RequestIDKey, requestID)
childSS := grpcServerStreamWithContext{
Expand All @@ -106,6 +116,7 @@ func StreamServerInterceptor(a Auditing, logger *slog.Logger, shouldAudit func(f
}

auditReqContext := Entry{
Timestamp: time.Now(),
RequestId: requestID,
Detail: EntryDetailGRPCStream,
Path: info.FullMethod,
Expand Down Expand Up @@ -161,11 +172,16 @@ func (a auditingConnectInterceptor) WrapStreamingClient(next connect.StreamingCl
requestID = str
}
if requestID == "" {
requestID = uuid.NewString()
uuid, err := uuid.NewV7()
if err != nil {
a.logger.Error("unable to generate uuid", "error", err)
}
requestID = uuid.String()
}
childCtx := context.WithValue(ctx, rest.RequestIDKey, requestID)

auditReqContext := Entry{
Timestamp: time.Now(),
RequestId: requestID,
Detail: EntryDetailGRPCStream,
Path: s.Procedure,
Expand Down Expand Up @@ -210,11 +226,16 @@ func (a auditingConnectInterceptor) WrapStreamingHandler(next connect.StreamingH
requestID = str
}
if requestID == "" {
requestID = uuid.NewString()
uuid, err := uuid.NewV7()
if err != nil {
return err
}
requestID = uuid.String()
}
childCtx := context.WithValue(ctx, rest.RequestIDKey, requestID)

auditReqContext := Entry{
Timestamp: time.Now(),
RequestId: requestID,
Detail: EntryDetailGRPCStream,
Path: shc.Spec().Procedure,
Expand Down Expand Up @@ -273,11 +294,16 @@ func (i auditingConnectInterceptor) WrapUnary(next connect.UnaryFunc) connect.Un
requestID = str
}
if requestID == "" {
requestID = uuid.NewString()
uuid, err := uuid.NewV7()
if err != nil {
return nil, err
}
requestID = uuid.String()
}
childCtx := context.WithValue(ctx, rest.RequestIDKey, requestID)

auditReqContext := Entry{
Timestamp: time.Now(),
RequestId: requestID,
Detail: EntryDetailGRPCUnary,
Path: ar.Spec().Procedure,
Expand Down Expand Up @@ -378,9 +404,17 @@ func HttpFilter(a Auditing, logger *slog.Logger) (restful.FilterFunction, error)
requestID = str
}
if requestID == "" {
requestID = uuid.NewString()
uuid, err := uuid.NewV7()
if err != nil {
logger.Error("unable to generate uuid", "error", err)
_, _ = response.Write([]byte("unable to generate request uuid " + err.Error()))
response.WriteHeader(http.StatusInternalServerError)
return
}
requestID = uuid.String()
}
auditReqContext := Entry{
Timestamp: time.Now(),
RequestId: requestID,
Type: EntryTypeHTTP,
Detail: EntryDetail(r.Method),
Expand Down
53 changes: 30 additions & 23 deletions auditing/auditing.go
Original file line number Diff line number Diff line change
@@ -1,18 +1,16 @@
package auditing

import (
"context"
"log/slog"
"os"
"path/filepath"
"time"
)

type Config struct {
Component string
URL string
APIKey string
IndexPrefix string
RotationInterval Interval
Keep int64
Log *slog.Logger
Component string
Log *slog.Logger
}

type Interval string
Expand Down Expand Up @@ -52,31 +50,31 @@ const (
const EntryFilterDefaultLimit int64 = 100

type Entry struct {
Id string // filled by the auditing driver
Component string
RequestId string `json:"rqid"`
Type EntryType
Timestamp time.Time
Id string `json:"-"` // filled by the auditing driver
Component string `json:"component"`
RequestId string `json:"rqid"`
Type EntryType `json:"type"`
Timestamp time.Time `json:"timestamp"`

User string
Tenant string
User string `json:"user"`
Tenant string `json:"tenant"`

// For `EntryDetailHTTP` the HTTP method get, post, put, delete, ...
// For `EntryDetailGRPC` unary, stream
Detail EntryDetail
Detail EntryDetail `json:"detail"`
// e.g. Request, Response, Error, Opened, Close
Phase EntryPhase
Phase EntryPhase `json:"phase"`
// For `EntryDetailHTTP` /api/v1/...
// For `EntryDetailGRPC` /api.v1/... (the method name)
Path string
ForwardedFor string
RemoteAddr string
Path string `json:"path"`
ForwardedFor string `json:"forwardedfor"`
RemoteAddr string `json:"remoteaddr"`

Body any // JSON, string or numbers
StatusCode int // for `EntryDetailHTTP` the HTTP status code, for EntryDetailGRPC` the grpc status code
Body any `json:"body"` // JSON, string or numbers
StatusCode int `json:"statuscode"` // for `EntryDetailHTTP` the HTTP status code, for EntryDetailGRPC` the grpc status code

// Internal errors
Error error
Error error `json:"error"`
}

func (e *Entry) prepareForNextPhase() {
Expand Down Expand Up @@ -135,5 +133,14 @@ type Auditing interface {
// Searches for entries matching the given filter.
// By default only recent entries will be returned.
// The returned entries will be sorted by timestamp in descending order.
Search(EntryFilter) ([]Entry, error)
Search(context.Context, EntryFilter) ([]Entry, error)
}

func defaultComponent() (string, error) {
ex, err := os.Executable()
if err != nil {
return "", err
}

return filepath.Base(ex), nil
}
34 changes: 21 additions & 13 deletions auditing/meilisearch.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,6 @@ import (
"errors"
"fmt"
"log/slog"
"os"
"path/filepath"
"regexp"
"slices"
"strings"
Expand All @@ -17,6 +15,15 @@ import (
"github.com/meilisearch/meilisearch-go"
)

type MeilisearchConfig struct {
URL string
APIKey string

IndexPrefix string
RotationInterval Interval
Keep int64
}

type meiliAuditing struct {
component string
client *meilisearch.Client
Expand All @@ -39,32 +46,33 @@ const (
meiliIndexCreationWaitInterval = 100 * time.Millisecond
)

func New(c Config) (Auditing, error) {
func NewMeilisearch(c Config, mc MeilisearchConfig) (Auditing, error) {
if c.Component == "" {
ex, err := os.Executable()
component, err := defaultComponent()
if err != nil {
return nil, err
}
c.Component = filepath.Base(ex)

c.Component = component
}

client := meilisearch.NewClient(meilisearch.ClientConfig{
Host: c.URL,
APIKey: c.APIKey,
Host: mc.URL,
APIKey: mc.APIKey,
})
v, err := client.GetVersion()
if err != nil {
return nil, fmt.Errorf("unable to connect to meilisearch at:%s %w", c.URL, err)
return nil, fmt.Errorf("unable to connect to meilisearch at:%s %w", mc.URL, err)
}
c.Log.Info("meilisearch", "connected to", v, "index rotated", c.RotationInterval, "index keep", c.Keep)
c.Log.Info("meilisearch", "connected to", v, "index rotated", mc.RotationInterval, "index keep", mc.Keep)

a := &meiliAuditing{
component: c.Component,
client: client,
log: c.Log.WithGroup("auditing"),
indexPrefix: c.IndexPrefix,
rotationInterval: c.RotationInterval,
keep: c.Keep,
indexPrefix: mc.IndexPrefix,
rotationInterval: mc.RotationInterval,
keep: mc.Keep,
}
return a, nil
}
Expand Down Expand Up @@ -121,7 +129,7 @@ func (a *meiliAuditing) Index(entry Entry) error {
return nil
}

func (a *meiliAuditing) Search(filter EntryFilter) ([]Entry, error) {
func (a *meiliAuditing) Search(_ context.Context, filter EntryFilter) ([]Entry, error) {
predicates := make([]string, 0)
if filter.Component != "" {
predicates = append(predicates, fmt.Sprintf("component = %q", filter.Component))
Expand Down
20 changes: 11 additions & 9 deletions auditing/meilisearch_integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@ func StartMeilisearch(t testing.TB) (container testcontainers.Container, c *conn
}

func TestAuditing_Meilisearch(t *testing.T) {
ctx := context.Background()
container, c, err := StartMeilisearch(t)
require.NoError(t, err)
defer func() {
Expand Down Expand Up @@ -143,7 +144,7 @@ func TestAuditing_Meilisearch(t *testing.T) {
{
name: "no entries, no search results",
t: func(t *testing.T, a Auditing) {
entries, err := a.Search(EntryFilter{})
entries, err := a.Search(ctx, EntryFilter{})
require.NoError(t, err)
assert.Empty(t, entries)
},
Expand All @@ -158,7 +159,7 @@ func TestAuditing_Meilisearch(t *testing.T) {
err = a.Flush()
require.NoError(t, err)

entries, err := a.Search(EntryFilter{
entries, err := a.Search(ctx, EntryFilter{
Body: "test",
})
require.NoError(t, err)
Expand All @@ -177,7 +178,7 @@ func TestAuditing_Meilisearch(t *testing.T) {
err = a.Flush()
require.NoError(t, err)

entries, err := a.Search(EntryFilter{})
entries, err := a.Search(ctx, EntryFilter{})
require.NoError(t, err)
assert.Len(t, entries, len(es))

Expand All @@ -187,7 +188,7 @@ func TestAuditing_Meilisearch(t *testing.T) {
t.Errorf("diff (+got -want):\n %s", diff)
}

entries, err = a.Search(EntryFilter{
entries, err = a.Search(ctx, EntryFilter{
Body: "This",
})
require.NoError(t, err)
Expand All @@ -206,7 +207,7 @@ func TestAuditing_Meilisearch(t *testing.T) {
err = a.Flush()
require.NoError(t, err)

entries, err := a.Search(EntryFilter{
entries, err := a.Search(ctx, EntryFilter{
RequestId: es[0].RequestId,
})
require.NoError(t, err)
Expand Down Expand Up @@ -234,7 +235,7 @@ func TestAuditing_Meilisearch(t *testing.T) {
err = a.Flush()
require.NoError(t, err)

entries, err := a.Search(EntryFilter{
entries, err := a.Search(ctx, EntryFilter{
Phase: EntryPhaseResponse,
})
require.NoError(t, err)
Expand All @@ -259,7 +260,7 @@ func TestAuditing_Meilisearch(t *testing.T) {
err = a.Flush()
require.NoError(t, err)

entries, err := a.Search(EntryFilter{
entries, err := a.Search(ctx, EntryFilter{
// we want to run a phrase search as otherwise we return the other entries as well
// https://www.meilisearch.com/docs/reference/api/search#phrase-search-2
Body: fmt.Sprintf("%q", es[0].Body.(string)),
Expand All @@ -277,10 +278,11 @@ func TestAuditing_Meilisearch(t *testing.T) {
tt := tt

t.Run(fmt.Sprintf("%d %s", i, tt.name), func(t *testing.T) {
a, err := New(Config{
a, err := NewMeilisearch(Config{
Log: slog.Default(),
}, MeilisearchConfig{
URL: c.Endpoint,
APIKey: c.Password,
Log: slog.Default(),
IndexPrefix: fmt.Sprintf("test-%d", i),
})
require.NoError(t, err)
Expand Down
Loading
Loading