Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 7 additions & 1 deletion pkg/database/database.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,12 @@ type SingleResultInterface interface {
Err() error
}

// FindResultInterface defines the interface for find query results
type FindResultInterface interface {
All(dest any) error
Err() error
}

// UpdateResultInterface defines the interface for update operation results
type UpdateResultInterface interface {
MatchedCount() int64
Expand All @@ -26,7 +32,7 @@ type UpdateResultInterface interface {
type DatabaseInterface interface {
GetTimeout() time.Duration
Ping(context.Context) error
Find(ctx context.Context, db string, collection string, filter any, opts ...any) (any, error)
Find(ctx context.Context, db string, collection string, filter any, opts ...any) FindResultInterface
FindOne(ctx context.Context, db string, collection string, filter any, opts ...any) SingleResultInterface
UpdateOne(ctx context.Context, db string, collection string, filter any, update any, opts ...any) (UpdateResultInterface, error)
Disconnect(ctx context.Context) error
Expand Down
93 changes: 81 additions & 12 deletions pkg/database/mock.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package database

import (
"context"
"encoding/json"
"fmt"
"time"

Expand All @@ -13,8 +14,8 @@ type MockDatabase struct {
// PingFunc allows customizing Ping behavior
PingFunc func(ctx context.Context) error

// FindFunc allows customizing Find behavior
FindFunc func(ctx context.Context, db string, collection string, filter any, opts ...any) (any, error)
// FindFunc allows customizing Find behavior - returns a FindResultInterface
FindFunc func(ctx context.Context, db string, collection string, filter any, opts ...any) FindResultInterface

// FindOneFunc allows customizing FindOne behavior - returns a SingleResultInterface
FindOneFunc func(ctx context.Context, db string, collection string, filter any, opts ...any) SingleResultInterface
Expand Down Expand Up @@ -49,6 +50,28 @@ type MockUpdateResult struct {
upsertedID any
}

// MockFindResult implements FindResultInterface for testing
type MockFindResult struct {
results any
err error
}

// All decodes all results into dest
func (m *MockFindResult) All(dest any) error {
if m.err != nil {
return m.err
}
if m.results == nil {
return nil
}
return copySliceResult(m.results, dest)
}

// Err returns any error
func (m *MockFindResult) Err() error {
return m.err
}

// MatchedCount returns the number of documents matched
func (m *MockUpdateResult) MatchedCount() int64 {
return m.matchedCount
Expand Down Expand Up @@ -155,8 +178,8 @@ func NewMockDatabase() *MockDatabase {
PingFunc: func(ctx context.Context) error {
return nil
},
FindFunc: func(ctx context.Context, db string, collection string, filter any, opts ...any) (any, error) {
return []any{}, nil
FindFunc: func(ctx context.Context, db string, collection string, filter any, opts ...any) FindResultInterface {
return &MockFindResult{results: []any{}, err: nil}
},
FindOneFunc: func(ctx context.Context, db string, collection string, filter any, opts ...any) SingleResultInterface {
return &MockSingleResult{result: nil, err: fmt.Errorf("no document found")}
Expand Down Expand Up @@ -204,7 +227,7 @@ func (m *MockDatabase) GetTimeout() time.Duration {
}

// Find implements DatabaseInterface
func (m *MockDatabase) Find(ctx context.Context, db string, collection string, filter any, opts ...any) (any, error) {
func (m *MockDatabase) Find(ctx context.Context, db string, collection string, filter any, opts ...any) FindResultInterface {
m.FindCalls = append(m.FindCalls, FindCall{
Ctx: ctx,
Db: db,
Expand All @@ -213,18 +236,34 @@ func (m *MockDatabase) Find(ctx context.Context, db string, collection string, f
Opts: opts,
})

var result any
var err error

// Check if there's a queued response
if len(m.FindQueue) > 0 {
response := m.FindQueue[0]
m.FindQueue = m.FindQueue[1:]
return response.Result, response.Err
result = response.Result
err = response.Err
} else if m.FindFunc != nil {
// Fall back to FindFunc
return m.FindFunc(ctx, db, collection, filter, opts...)
} else {
result = []any{}
err = nil
}

// Fall back to FindFunc
if m.FindFunc != nil {
return m.FindFunc(ctx, db, collection, filter, opts...)
// Apply projection if present
if result != nil && err == nil {
for _, opt := range opts {
if proj, ok := opt.(*Projection); ok {
result = applyProjectionToSlice(result, proj)
break
}
}
}
return []any{}, nil

return &MockFindResult{results: result, err: err}
}

// FindOne implements DatabaseInterface
Expand Down Expand Up @@ -277,6 +316,17 @@ func copyResult(src any, dest any) error {
return bson.Unmarshal(bytes, dest)
}

// copySliceResult copies a slice from src into dest using JSON marshaling
// This is simpler than BSON for arrays at the top level
func copySliceResult(src any, dest any) error {
// Use standard JSON which handles arrays at top level correctly
bytes, err := json.Marshal(src)
if err != nil {
return err
}
return json.Unmarshal(bytes, dest)
}

// applyProjection filters fields from a result based on projection rules
func applyProjection(result any, proj *Projection) any {
if proj == nil || len(proj.fields) == 0 {
Expand Down Expand Up @@ -327,6 +377,25 @@ func applyProjection(result any, proj *Projection) any {
return result
}

// applyProjectionToSlice applies projection to a slice of results
func applyProjectionToSlice(results any, proj *Projection) any {
if proj == nil || len(proj.fields) == 0 {
return results
}

// Try to convert to slice
slice, ok := results.([]any)
if !ok {
return results
}

projected := make([]any, len(slice))
for i, item := range slice {
projected[i] = applyProjection(item, proj)
}
return projected
}

// UpdateOne implements DatabaseInterface
func (m *MockDatabase) UpdateOne(ctx context.Context, db string, collection string, filter any, update any, opts ...any) (UpdateResultInterface, error) {
m.UpdateOneCalls = append(m.UpdateOneCalls, UpdateOneCall{
Expand Down Expand Up @@ -392,8 +461,8 @@ func (m *MockDatabase) ExpectPing(err error) *MockDatabase {

// ExpectFind sets up an expectation for Find
func (m *MockDatabase) ExpectFind(result any, err error) *MockDatabase {
m.FindFunc = func(ctx context.Context, db string, collection string, filter any, opts ...any) (any, error) {
return result, err
m.FindFunc = func(ctx context.Context, db string, collection string, filter any, opts ...any) FindResultInterface {
return &MockFindResult{results: result, err: err}
}
return m
}
Expand Down
63 changes: 33 additions & 30 deletions pkg/database/mock_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,12 @@ func TestMockDatabase(t *testing.T) {
}

// Test Find default (should return empty slice)
result, err := mock.Find(context.Background(), "testdb", "users", map[string]any{"id": 1})
var results []any
err = mock.Find(context.Background(), "testdb", "users", map[string]any{"id": 1}).All(&results)
if err != nil {
t.Errorf("expected nil error, got %v", err)
}
if result == nil {
if results == nil {
t.Error("expected non-nil result")
}

Expand Down Expand Up @@ -62,16 +63,12 @@ func TestMockDatabase(t *testing.T) {

mock.ExpectFind(expectedData, nil)

result, err := mock.Find(context.Background(), "testdb", "users", map[string]any{})
var resultSlice []map[string]any
err := mock.Find(context.Background(), "testdb", "users", map[string]any{}).All(&resultSlice)
if err != nil {
t.Errorf("expected nil error, got %v", err)
}

resultSlice, ok := result.([]map[string]any)
if !ok {
t.Fatal("expected result to be []map[string]any")
}

if len(resultSlice) != 2 {
t.Errorf("expected 2 results, got %d", len(resultSlice))
}
Expand Down Expand Up @@ -137,37 +134,39 @@ func TestMockDatabase(t *testing.T) {
mock := NewMockDatabase()

// Custom function that returns different results based on filter
mock.FindFunc = func(ctx context.Context, db string, collection string, filter any, opts ...any) (any, error) {
mock.FindFunc = func(ctx context.Context, db string, collection string, filter any, opts ...any) FindResultInterface {
filterMap, ok := filter.(map[string]any)
if !ok {
return nil, errors.New("invalid filter")
return &MockFindResult{results: nil, err: errors.New("invalid filter")}
}

if status, ok := filterMap["status"]; ok && status == "active" {
return []map[string]any{
return &MockFindResult{results: []map[string]any{
{"id": 1, "status": "active"},
{"id": 2, "status": "active"},
}, nil
}, err: nil}
}

return []map[string]any{}, nil
return &MockFindResult{results: []map[string]any{}, err: nil}
}

// Test with active status
result, err := mock.Find(context.Background(), "testdb", "users", map[string]any{"status": "active"})
var activeResults []map[string]any
err := mock.Find(context.Background(), "testdb", "users", map[string]any{"status": "active"}).All(&activeResults)
if err != nil {
t.Errorf("expected nil error, got %v", err)
}
if len(result.([]map[string]any)) != 2 {
if len(activeResults) != 2 {
t.Errorf("expected 2 results for active users")
}

// Test with inactive status
result, err = mock.Find(context.Background(), "testdb", "users", map[string]any{"status": "inactive"})
var inactiveResults []map[string]any
err = mock.Find(context.Background(), "testdb", "users", map[string]any{"status": "inactive"}).All(&inactiveResults)
if err != nil {
t.Errorf("expected nil error, got %v", err)
}
if len(result.([]map[string]any)) != 0 {
if len(inactiveResults) != 0 {
t.Errorf("expected 0 results for inactive users")
}

Expand Down Expand Up @@ -247,41 +246,42 @@ func TestMockDatabaseSequentialCalls(t *testing.T) {
QueueFind(settings, nil)

// First call returns users
result1, err := mock.Find(context.Background(), "testdb", "users", map[string]any{})
var usersResult []map[string]any
err := mock.Find(context.Background(), "testdb", "users", map[string]any{}).All(&usersResult)
if err != nil {
t.Errorf("unexpected error on first call: %v", err)
}
usersResult := result1.([]map[string]any)
if len(usersResult) != 2 || usersResult[0]["name"] != "Alice" {
t.Error("first call should return users")
}

// Second call returns notifications
result2, err := mock.Find(context.Background(), "testdb", "notifications", map[string]any{})
var notificationsResult []map[string]any
err = mock.Find(context.Background(), "testdb", "notifications", map[string]any{}).All(&notificationsResult)
if err != nil {
t.Errorf("unexpected error on second call: %v", err)
}
notificationsResult := result2.([]map[string]any)
if len(notificationsResult) != 2 || notificationsResult[0]["message"] != "Hello" {
t.Error("second call should return notifications")
}

// Third call returns settings
result3, err := mock.Find(context.Background(), "testdb", "settings", map[string]any{})
var settingsResult []map[string]any
err = mock.Find(context.Background(), "testdb", "settings", map[string]any{}).All(&settingsResult)
if err != nil {
t.Errorf("unexpected error on third call: %v", err)
}
settingsResult := result3.([]map[string]any)
if len(settingsResult) != 1 || settingsResult[0]["key"] != "theme" {
t.Error("third call should return settings")
}

// Fourth call falls back to default behavior (empty slice)
result4, err := mock.Find(context.Background(), "testdb", "other", map[string]any{})
var otherResult []any
err = mock.Find(context.Background(), "testdb", "other", map[string]any{}).All(&otherResult)
if err != nil {
t.Errorf("unexpected error on fourth call: %v", err)
}
if len(result4.([]any)) != 0 {
if len(otherResult) != 0 {
t.Error("fourth call should return empty slice (default)")
}

Expand All @@ -300,26 +300,29 @@ func TestMockDatabaseSequentialCalls(t *testing.T) {
QueueFind([]map[string]any{{"id": 2}}, nil)

// First call succeeds
result1, err := mock.Find(context.Background(), "testdb", "users", map[string]any{})
var result1 []map[string]any
err := mock.Find(context.Background(), "testdb", "users", map[string]any{}).All(&result1)
if err != nil {
t.Errorf("expected no error, got %v", err)
}
if len(result1.([]map[string]any)) != 1 {
if len(result1) != 1 {
t.Error("first call should return 1 result")
}

// Second call returns error
_, err = mock.Find(context.Background(), "testdb", "users", map[string]any{})
var result2 []map[string]any
err = mock.Find(context.Background(), "testdb", "users", map[string]any{}).All(&result2)
if err == nil || err.Error() != "connection timeout" {
t.Errorf("expected 'connection timeout' error, got %v", err)
}

// Third call succeeds again
result3, err := mock.Find(context.Background(), "testdb", "users", map[string]any{})
var result3 []map[string]any
err = mock.Find(context.Background(), "testdb", "users", map[string]any{}).All(&result3)
if err != nil {
t.Errorf("expected no error, got %v", err)
}
if len(result3.([]map[string]any)) != 1 {
if len(result3) != 1 {
t.Error("third call should return 1 result")
}
})
Expand Down
Loading
Loading