From 78f8c0fdf7dfa2c0b81ba05db46b6f4c83629817 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?C=C3=A9dric=20Verstraeten?= Date: Wed, 4 Feb 2026 14:04:25 +0000 Subject: [PATCH] Add FindResultInterface and update Find methods for fluent API support --- pkg/database/database.go | 8 +++- pkg/database/mock.go | 93 +++++++++++++++++++++++++++++++----- pkg/database/mock_test.go | 63 ++++++++++++------------ pkg/database/mongodb.go | 46 +++++++++++++----- pkg/database/mongodb_test.go | 9 +--- 5 files changed, 156 insertions(+), 63 deletions(-) diff --git a/pkg/database/database.go b/pkg/database/database.go index 6d35262..784bd51 100644 --- a/pkg/database/database.go +++ b/pkg/database/database.go @@ -15,6 +15,12 @@ type SingleResultInterface interface { Err() error } +// FindResultInterface defines the interface for find query results +type FindResultInterface interface { + All(dest any) error + Err() error +} + // UpdateResultInterface defines the interface for update operation results type UpdateResultInterface interface { MatchedCount() int64 @@ -26,7 +32,7 @@ type UpdateResultInterface interface { type DatabaseInterface interface { GetTimeout() time.Duration Ping(context.Context) error - Find(ctx context.Context, db string, collection string, filter any, opts ...any) (any, error) + Find(ctx context.Context, db string, collection string, filter any, opts ...any) FindResultInterface FindOne(ctx context.Context, db string, collection string, filter any, opts ...any) SingleResultInterface UpdateOne(ctx context.Context, db string, collection string, filter any, update any, opts ...any) (UpdateResultInterface, error) Disconnect(ctx context.Context) error diff --git a/pkg/database/mock.go b/pkg/database/mock.go index bd4236d..d40caf6 100644 --- a/pkg/database/mock.go +++ b/pkg/database/mock.go @@ -2,6 +2,7 @@ package database import ( "context" + "encoding/json" "fmt" "time" @@ -13,8 +14,8 @@ type MockDatabase struct { // PingFunc allows customizing Ping behavior PingFunc func(ctx context.Context) error - // FindFunc allows customizing Find behavior - FindFunc func(ctx context.Context, db string, collection string, filter any, opts ...any) (any, error) + // FindFunc allows customizing Find behavior - returns a FindResultInterface + FindFunc func(ctx context.Context, db string, collection string, filter any, opts ...any) FindResultInterface // FindOneFunc allows customizing FindOne behavior - returns a SingleResultInterface FindOneFunc func(ctx context.Context, db string, collection string, filter any, opts ...any) SingleResultInterface @@ -49,6 +50,28 @@ type MockUpdateResult struct { upsertedID any } +// MockFindResult implements FindResultInterface for testing +type MockFindResult struct { + results any + err error +} + +// All decodes all results into dest +func (m *MockFindResult) All(dest any) error { + if m.err != nil { + return m.err + } + if m.results == nil { + return nil + } + return copySliceResult(m.results, dest) +} + +// Err returns any error +func (m *MockFindResult) Err() error { + return m.err +} + // MatchedCount returns the number of documents matched func (m *MockUpdateResult) MatchedCount() int64 { return m.matchedCount @@ -155,8 +178,8 @@ func NewMockDatabase() *MockDatabase { PingFunc: func(ctx context.Context) error { return nil }, - FindFunc: func(ctx context.Context, db string, collection string, filter any, opts ...any) (any, error) { - return []any{}, nil + FindFunc: func(ctx context.Context, db string, collection string, filter any, opts ...any) FindResultInterface { + return &MockFindResult{results: []any{}, err: nil} }, FindOneFunc: func(ctx context.Context, db string, collection string, filter any, opts ...any) SingleResultInterface { return &MockSingleResult{result: nil, err: fmt.Errorf("no document found")} @@ -204,7 +227,7 @@ func (m *MockDatabase) GetTimeout() time.Duration { } // Find implements DatabaseInterface -func (m *MockDatabase) Find(ctx context.Context, db string, collection string, filter any, opts ...any) (any, error) { +func (m *MockDatabase) Find(ctx context.Context, db string, collection string, filter any, opts ...any) FindResultInterface { m.FindCalls = append(m.FindCalls, FindCall{ Ctx: ctx, Db: db, @@ -213,18 +236,34 @@ func (m *MockDatabase) Find(ctx context.Context, db string, collection string, f Opts: opts, }) + var result any + var err error + // Check if there's a queued response if len(m.FindQueue) > 0 { response := m.FindQueue[0] m.FindQueue = m.FindQueue[1:] - return response.Result, response.Err + result = response.Result + err = response.Err + } else if m.FindFunc != nil { + // Fall back to FindFunc + return m.FindFunc(ctx, db, collection, filter, opts...) + } else { + result = []any{} + err = nil } - // Fall back to FindFunc - if m.FindFunc != nil { - return m.FindFunc(ctx, db, collection, filter, opts...) + // Apply projection if present + if result != nil && err == nil { + for _, opt := range opts { + if proj, ok := opt.(*Projection); ok { + result = applyProjectionToSlice(result, proj) + break + } + } } - return []any{}, nil + + return &MockFindResult{results: result, err: err} } // FindOne implements DatabaseInterface @@ -277,6 +316,17 @@ func copyResult(src any, dest any) error { return bson.Unmarshal(bytes, dest) } +// copySliceResult copies a slice from src into dest using JSON marshaling +// This is simpler than BSON for arrays at the top level +func copySliceResult(src any, dest any) error { + // Use standard JSON which handles arrays at top level correctly + bytes, err := json.Marshal(src) + if err != nil { + return err + } + return json.Unmarshal(bytes, dest) +} + // applyProjection filters fields from a result based on projection rules func applyProjection(result any, proj *Projection) any { if proj == nil || len(proj.fields) == 0 { @@ -327,6 +377,25 @@ func applyProjection(result any, proj *Projection) any { return result } +// applyProjectionToSlice applies projection to a slice of results +func applyProjectionToSlice(results any, proj *Projection) any { + if proj == nil || len(proj.fields) == 0 { + return results + } + + // Try to convert to slice + slice, ok := results.([]any) + if !ok { + return results + } + + projected := make([]any, len(slice)) + for i, item := range slice { + projected[i] = applyProjection(item, proj) + } + return projected +} + // UpdateOne implements DatabaseInterface func (m *MockDatabase) UpdateOne(ctx context.Context, db string, collection string, filter any, update any, opts ...any) (UpdateResultInterface, error) { m.UpdateOneCalls = append(m.UpdateOneCalls, UpdateOneCall{ @@ -392,8 +461,8 @@ func (m *MockDatabase) ExpectPing(err error) *MockDatabase { // ExpectFind sets up an expectation for Find func (m *MockDatabase) ExpectFind(result any, err error) *MockDatabase { - m.FindFunc = func(ctx context.Context, db string, collection string, filter any, opts ...any) (any, error) { - return result, err + m.FindFunc = func(ctx context.Context, db string, collection string, filter any, opts ...any) FindResultInterface { + return &MockFindResult{results: result, err: err} } return m } diff --git a/pkg/database/mock_test.go b/pkg/database/mock_test.go index 99d90eb..9adfe57 100644 --- a/pkg/database/mock_test.go +++ b/pkg/database/mock_test.go @@ -18,11 +18,12 @@ func TestMockDatabase(t *testing.T) { } // Test Find default (should return empty slice) - result, err := mock.Find(context.Background(), "testdb", "users", map[string]any{"id": 1}) + var results []any + err = mock.Find(context.Background(), "testdb", "users", map[string]any{"id": 1}).All(&results) if err != nil { t.Errorf("expected nil error, got %v", err) } - if result == nil { + if results == nil { t.Error("expected non-nil result") } @@ -62,16 +63,12 @@ func TestMockDatabase(t *testing.T) { mock.ExpectFind(expectedData, nil) - result, err := mock.Find(context.Background(), "testdb", "users", map[string]any{}) + var resultSlice []map[string]any + err := mock.Find(context.Background(), "testdb", "users", map[string]any{}).All(&resultSlice) if err != nil { t.Errorf("expected nil error, got %v", err) } - resultSlice, ok := result.([]map[string]any) - if !ok { - t.Fatal("expected result to be []map[string]any") - } - if len(resultSlice) != 2 { t.Errorf("expected 2 results, got %d", len(resultSlice)) } @@ -137,37 +134,39 @@ func TestMockDatabase(t *testing.T) { mock := NewMockDatabase() // Custom function that returns different results based on filter - mock.FindFunc = func(ctx context.Context, db string, collection string, filter any, opts ...any) (any, error) { + mock.FindFunc = func(ctx context.Context, db string, collection string, filter any, opts ...any) FindResultInterface { filterMap, ok := filter.(map[string]any) if !ok { - return nil, errors.New("invalid filter") + return &MockFindResult{results: nil, err: errors.New("invalid filter")} } if status, ok := filterMap["status"]; ok && status == "active" { - return []map[string]any{ + return &MockFindResult{results: []map[string]any{ {"id": 1, "status": "active"}, {"id": 2, "status": "active"}, - }, nil + }, err: nil} } - return []map[string]any{}, nil + return &MockFindResult{results: []map[string]any{}, err: nil} } // Test with active status - result, err := mock.Find(context.Background(), "testdb", "users", map[string]any{"status": "active"}) + var activeResults []map[string]any + err := mock.Find(context.Background(), "testdb", "users", map[string]any{"status": "active"}).All(&activeResults) if err != nil { t.Errorf("expected nil error, got %v", err) } - if len(result.([]map[string]any)) != 2 { + if len(activeResults) != 2 { t.Errorf("expected 2 results for active users") } // Test with inactive status - result, err = mock.Find(context.Background(), "testdb", "users", map[string]any{"status": "inactive"}) + var inactiveResults []map[string]any + err = mock.Find(context.Background(), "testdb", "users", map[string]any{"status": "inactive"}).All(&inactiveResults) if err != nil { t.Errorf("expected nil error, got %v", err) } - if len(result.([]map[string]any)) != 0 { + if len(inactiveResults) != 0 { t.Errorf("expected 0 results for inactive users") } @@ -247,41 +246,42 @@ func TestMockDatabaseSequentialCalls(t *testing.T) { QueueFind(settings, nil) // First call returns users - result1, err := mock.Find(context.Background(), "testdb", "users", map[string]any{}) + var usersResult []map[string]any + err := mock.Find(context.Background(), "testdb", "users", map[string]any{}).All(&usersResult) if err != nil { t.Errorf("unexpected error on first call: %v", err) } - usersResult := result1.([]map[string]any) if len(usersResult) != 2 || usersResult[0]["name"] != "Alice" { t.Error("first call should return users") } // Second call returns notifications - result2, err := mock.Find(context.Background(), "testdb", "notifications", map[string]any{}) + var notificationsResult []map[string]any + err = mock.Find(context.Background(), "testdb", "notifications", map[string]any{}).All(¬ificationsResult) if err != nil { t.Errorf("unexpected error on second call: %v", err) } - notificationsResult := result2.([]map[string]any) if len(notificationsResult) != 2 || notificationsResult[0]["message"] != "Hello" { t.Error("second call should return notifications") } // Third call returns settings - result3, err := mock.Find(context.Background(), "testdb", "settings", map[string]any{}) + var settingsResult []map[string]any + err = mock.Find(context.Background(), "testdb", "settings", map[string]any{}).All(&settingsResult) if err != nil { t.Errorf("unexpected error on third call: %v", err) } - settingsResult := result3.([]map[string]any) if len(settingsResult) != 1 || settingsResult[0]["key"] != "theme" { t.Error("third call should return settings") } // Fourth call falls back to default behavior (empty slice) - result4, err := mock.Find(context.Background(), "testdb", "other", map[string]any{}) + var otherResult []any + err = mock.Find(context.Background(), "testdb", "other", map[string]any{}).All(&otherResult) if err != nil { t.Errorf("unexpected error on fourth call: %v", err) } - if len(result4.([]any)) != 0 { + if len(otherResult) != 0 { t.Error("fourth call should return empty slice (default)") } @@ -300,26 +300,29 @@ func TestMockDatabaseSequentialCalls(t *testing.T) { QueueFind([]map[string]any{{"id": 2}}, nil) // First call succeeds - result1, err := mock.Find(context.Background(), "testdb", "users", map[string]any{}) + var result1 []map[string]any + err := mock.Find(context.Background(), "testdb", "users", map[string]any{}).All(&result1) if err != nil { t.Errorf("expected no error, got %v", err) } - if len(result1.([]map[string]any)) != 1 { + if len(result1) != 1 { t.Error("first call should return 1 result") } // Second call returns error - _, err = mock.Find(context.Background(), "testdb", "users", map[string]any{}) + var result2 []map[string]any + err = mock.Find(context.Background(), "testdb", "users", map[string]any{}).All(&result2) if err == nil || err.Error() != "connection timeout" { t.Errorf("expected 'connection timeout' error, got %v", err) } // Third call succeeds again - result3, err := mock.Find(context.Background(), "testdb", "users", map[string]any{}) + var result3 []map[string]any + err = mock.Find(context.Background(), "testdb", "users", map[string]any{}).All(&result3) if err != nil { t.Errorf("expected no error, got %v", err) } - if len(result3.([]map[string]any)) != 1 { + if len(result3) != 1 { t.Error("third call should return 1 result") } }) diff --git a/pkg/database/mongodb.go b/pkg/database/mongodb.go index d17f6c5..bd50651 100644 --- a/pkg/database/mongodb.go +++ b/pkg/database/mongodb.go @@ -225,8 +225,35 @@ func (m *MongoClient) GetTimeout() time.Duration { return time.Duration(m.Options.Timeout) * time.Millisecond } -// Find executes a find query on the specified database and collection -func (m *MongoClient) Find(ctx context.Context, db string, collection string, filter any, opts ...any) (any, error) { +// FindResult wraps a MongoDB cursor for fluent API usage +type FindResult struct { + cursor *mongo.Cursor + ctx context.Context + err error +} + +// All decodes all results into the provided destination slice. +// The dest parameter must be a pointer to a slice. +func (fr *FindResult) All(dest any) error { + if fr.err != nil { + return fr.err + } + if fr.cursor == nil { + return fr.err + } + defer fr.cursor.Close(fr.ctx) + return fr.cursor.All(fr.ctx, dest) +} + +// Err returns any error that occurred during the query. +func (fr *FindResult) Err() error { + return fr.err +} + +// Find executes a find query on the specified database and collection. +// Returns a FindResult that can be used with .All() for fluent decoding. +// Supports *moptions.FindOptions and *Projection in opts. +func (m *MongoClient) Find(ctx context.Context, db string, collection string, filter any, opts ...any) FindResultInterface { coll := m.Client.Database(db).Collection(collection) // Convert opts to mongo.FindOptions if provided @@ -234,21 +261,14 @@ func (m *MongoClient) Find(ctx context.Context, db string, collection string, fi for _, opt := range opts { if fo, ok := opt.(*moptions.FindOptions); ok { findOpts = append(findOpts, fo) + } else if proj, ok := opt.(*Projection); ok { + // Convert Projection to FindOptions with SetProjection + findOpts = append(findOpts, moptions.Find().SetProjection(proj.toBSON())) } } cursor, err := coll.Find(ctx, filter, findOpts...) - if err != nil { - return nil, err - } - defer cursor.Close(ctx) - - var results []any - if err = cursor.All(ctx, &results); err != nil { - return nil, err - } - - return results, nil + return &FindResult{cursor: cursor, ctx: ctx, err: err} } // FindOne executes a findOne query on the specified database and collection. diff --git a/pkg/database/mongodb_test.go b/pkg/database/mongodb_test.go index fb5ee17..d166e82 100644 --- a/pkg/database/mongodb_test.go +++ b/pkg/database/mongodb_test.go @@ -310,17 +310,12 @@ func TestFindIntegration(t *testing.T) { // Test Find with username filter filter := map[string]any{"username": "cedricve"} - results, err := db.Client.Find(ctx, "Kerberos", "users", filter) + var resultSlice []any + err = db.Client.Find(ctx, "Kerberos", "users", filter).All(&resultSlice) if err != nil { t.Fatalf("Find failed: %v", err) } - // Validate results - resultSlice, ok := results.([]any) - if !ok { - t.Fatalf("expected results to be []any, got %T", results) - } - if len(resultSlice) != 1 { t.Fatalf("expected exactly 1 result for username 'cedricve', got %d", len(resultSlice)) }