diff --git a/README.md b/README.md index 512ef9d..c671d26 100644 --- a/README.md +++ b/README.md @@ -272,6 +272,8 @@ if err != nil { ## Testing +### Running Tests + Run the test suite: ```bash @@ -292,8 +294,141 @@ go test ./pkg/database -v # MongoDB tests go test ./pkg/database -run TestMongo + +# Mock tests +go test ./pkg/database -run TestMockDatabase ``` +### Mocking for Tests + +The package includes a complete mock implementation of the `DatabaseInterface` that allows you to control the behavior of database operations in your tests without needing a real database connection. + +#### Basic Mock Usage + +```go +import ( + "context" + "testing" + "github.com/uug-ai/database/pkg/database" +) + +func TestMyFunction(t *testing.T) { + // Create a new mock database + mock := database.NewMockDatabase() + + // Set up expectations for what the mock should return + expectedUser := map[string]any{ + "id": 1, + "name": "Alice", + "email": "alice@example.com", + } + mock.ExpectFindOne(expectedUser, nil) + + // Inject the mock into your Database instance + opts := database.NewMongoOptions(). + SetUri("mongodb://localhost"). + SetTimeout(5000). + Build() + + db, err := database.New(opts, mock) + if err != nil { + t.Fatalf("failed to create database: %v", err) + } + + // Use the database - it will use your mock + result, err := db.Client.FindOne(context.Background(), "testdb", "users", map[string]any{"id": 1}) + if err != nil { + t.Errorf("unexpected error: %v", err) + } + + // Verify the call was tracked + if len(mock.FindOneCalls) != 1 { + t.Errorf("expected 1 FindOne call, got %d", len(mock.FindOneCalls)) + } +} +``` + +#### Advanced Mock Features + +**Expect Multiple Results:** +```go +mock := database.NewMockDatabase() + +// Mock Find to return multiple documents +users := []map[string]any{ + {"id": 1, "name": "Alice"}, + {"id": 2, "name": "Bob"}, +} +mock.ExpectFind(users, nil) + +result, err := mock.Find(ctx, "testdb", "users", map[string]any{}) +// result contains the mocked users +``` + +**Expect Errors:** +```go +mock := database.NewMockDatabase() + +// Mock a connection error +mock.ExpectPing(errors.New("connection failed")) + +err := mock.Ping(ctx) +// err will be "connection failed" +``` + +**Custom Behavior:** +```go +mock := database.NewMockDatabase() + +// Define custom logic based on input +mock.FindFunc = func(ctx context.Context, db string, collection string, filter any, opts ...any) (any, error) { + filterMap := filter.(map[string]any) + + if filterMap["status"] == "active" { + return []map[string]any{{"id": 1, "status": "active"}}, nil + } + + return []map[string]any{}, nil +} +``` + +**Track Call History:** +```go +mock := database.NewMockDatabase() + +// Make some calls +mock.Find(ctx, "testdb", "users", map[string]any{}) +mock.FindOne(ctx, "testdb", "users", map[string]any{"id": 1}) + +// Verify the calls +if len(mock.FindCalls) != 1 { + t.Error("expected 1 Find call") +} + +if mock.FindCalls[0].Collection != "users" { + t.Error("expected collection to be 'users'") +} + +// Reset call history for the next test +mock.Reset() +``` + +#### Mock API + +The `MockDatabase` type provides: + +- **`NewMockDatabase()`**: Creates a new mock with sensible defaults +- **`ExpectPing(err error)`**: Set expected Ping behavior +- **`ExpectFind(result any, err error)`**: Set expected Find behavior +- **`ExpectFindOne(result any, err error)`**: Set expected FindOne behavior +- **`PingFunc`**: Custom function for Ping behavior +- **`FindFunc`**: Custom function for Find behavior +- **`FindOneFunc`**: Custom function for FindOne behavior +- **`PingCalls`**: Slice of all Ping calls made +- **`FindCalls`**: Slice of all Find calls made +- **`FindOneCalls`**: Slice of all FindOne calls made +- **`Reset()`**: Clear all call history + ## OpenTelemetry Integration This package includes built-in OpenTelemetry instrumentation for MongoDB operations: diff --git a/go.mod b/go.mod index 564550b..bd7c4c5 100644 --- a/go.mod +++ b/go.mod @@ -4,6 +4,7 @@ go 1.24.10 require ( github.com/go-playground/validator/v10 v10.30.1 + github.com/uug-ai/models v1.2.26 go.mongodb.org/mongo-driver v1.17.6 go.opentelemetry.io/contrib/instrumentation/go.mongodb.org/mongo-driver/mongo/otelmongo v0.64.0 ) diff --git a/go.sum b/go.sum index 6910971..b1976e0 100644 --- a/go.sum +++ b/go.sum @@ -31,6 +31,8 @@ github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZb github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U= github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U= +github.com/uug-ai/models v1.2.26 h1:gHqq/+HT7D9EXEUpgLJVWbfjC+CwYmRHBJoRsMZwJfI= +github.com/uug-ai/models v1.2.26/go.mod h1:0EHI6EKF/f2J1iXmFuPFuZZ2yv9Q6kphqcS8wzHYGd8= github.com/xdg-go/pbkdf2 v1.0.0 h1:Su7DPu48wXMwC3bs7MCNG+z4FhcyEuz5dlvchbq0B0c= github.com/xdg-go/pbkdf2 v1.0.0/go.mod h1:jrpuAogTd400dnrH08LKmI/xc1MbPOebTwRqcT5RDeI= github.com/xdg-go/scram v1.2.0 h1:bYKF2AEwG5rqd1BumT4gAnvwU/M9nBp2pTSxeZw7Wvs= diff --git a/pkg/database/database.go b/pkg/database/database.go index 36ee57d..5867241 100644 --- a/pkg/database/database.go +++ b/pkg/database/database.go @@ -1,9 +1,15 @@ package database -import "github.com/go-playground/validator/v10" +import ( + "context" + + "github.com/go-playground/validator/v10" +) type DatabaseInterface interface { - Ping() error + Ping(context.Context) error + Find(ctx context.Context, db string, collection string, filter any, opts ...any) (any, error) + FindOne(ctx context.Context, db string, collection string, filter any, opts ...any) (any, error) } // Database represents a database client instance diff --git a/pkg/database/mock.go b/pkg/database/mock.go new file mode 100644 index 0000000..2a5ce11 --- /dev/null +++ b/pkg/database/mock.go @@ -0,0 +1,207 @@ +package database + +import ( + "context" + "fmt" +) + +// MockDatabase is a mock implementation of DatabaseInterface for testing +type MockDatabase struct { + // PingFunc allows customizing Ping behavior + PingFunc func(ctx context.Context) error + + // FindFunc allows customizing Find behavior + FindFunc func(ctx context.Context, db string, collection string, filter any, opts ...any) (any, error) + + // FindOneFunc allows customizing FindOne behavior + FindOneFunc func(ctx context.Context, db string, collection string, filter any, opts ...any) (any, error) + + // Sequential response queues for multiple calls + PingQueue []PingResponse + FindQueue []FindResponse + FindOneQueue []FindOneResponse + + // Call tracking + PingCalls []PingCall + FindCalls []FindCall + FindOneCalls []FindOneCall +} + +// PingResponse represents a queued response for Ping +type PingResponse struct { + Err error +} + +// FindResponse represents a queued response for Find +type FindResponse struct { + Result any + Err error +} + +// FindOneResponse represents a queued response for FindOne +type FindOneResponse struct { + Result any + Err error +} + +// PingCall records a call to Ping +type PingCall struct { + Ctx context.Context +} + +// FindCall records a call to Find +type FindCall struct { + Ctx context.Context + Db string + Collection string + Filter any + Opts []any +} + +// FindOneCall records a call to FindOne +type FindOneCall struct { + Ctx context.Context + Db string + Collection string + Filter any + Opts []any +} + +// NewMockDatabase creates a new MockDatabase with sensible defaults +func NewMockDatabase() *MockDatabase { + return &MockDatabase{ + PingFunc: func(ctx context.Context) error { + return nil + }, + FindFunc: func(ctx context.Context, db string, collection string, filter any, opts ...any) (any, error) { + return []any{}, nil + }, + FindOneFunc: func(ctx context.Context, db string, collection string, filter any, opts ...any) (any, error) { + return nil, fmt.Errorf("no document found") + }, + PingCalls: []PingCall{}, + FindCalls: []FindCall{}, + FindOneCalls: []FindOneCall{}, + PingQueue: []PingResponse{}, + FindQueue: []FindResponse{}, + FindOneQueue: []FindOneResponse{}, + } +} + +// Ping implements DatabaseInterface +func (m *MockDatabase) Ping(ctx context.Context) error { + m.PingCalls = append(m.PingCalls, PingCall{Ctx: ctx}) + + // Check if there's a queued response + if len(m.PingQueue) > 0 { + response := m.PingQueue[0] + m.PingQueue = m.PingQueue[1:] + return response.Err + } + + // Fall back to PingFunc + if m.PingFunc != nil { + return m.PingFunc(ctx) + } + return nil +} + +// Find implements DatabaseInterface +func (m *MockDatabase) Find(ctx context.Context, db string, collection string, filter any, opts ...any) (any, error) { + m.FindCalls = append(m.FindCalls, FindCall{ + Ctx: ctx, + Db: db, + Collection: collection, + Filter: filter, + Opts: opts, + }) + + // Check if there's a queued response + if len(m.FindQueue) > 0 { + response := m.FindQueue[0] + m.FindQueue = m.FindQueue[1:] + return response.Result, response.Err + } + + // Fall back to FindFunc + if m.FindFunc != nil { + return m.FindFunc(ctx, db, collection, filter, opts...) + } + return []any{}, nil +} + +// FindOne implements DatabaseInterface +func (m *MockDatabase) FindOne(ctx context.Context, db string, collection string, filter any, opts ...any) (any, error) { + m.FindOneCalls = append(m.FindOneCalls, FindOneCall{ + Ctx: ctx, + Db: db, + Collection: collection, + Filter: filter, + Opts: opts, + }) + + // Check if there's a queued response + if len(m.FindOneQueue) > 0 { + response := m.FindOneQueue[0] + m.FindOneQueue = m.FindOneQueue[1:] + return response.Result, response.Err + } + + // Fall back to FindOneFunc + if m.FindOneFunc != nil { + return m.FindOneFunc(ctx, db, collection, filter, opts...) + } + return nil, fmt.Errorf("no document found") +} + +// Reset clears all recorded calls +func (m *MockDatabase) Reset() { + m.PingCalls = []PingCall{} + m.FindCalls = []FindCall{} + m.FindOneCalls = []FindOneCall{} + m.PingQueue = []PingResponse{} + m.FindQueue = []FindResponse{} + m.FindOneQueue = []FindOneResponse{} +} + +// ExpectPing sets up an expectation for Ping +func (m *MockDatabase) ExpectPing(err error) *MockDatabase { + m.PingFunc = func(ctx context.Context) error { + return err + } + return m +} + +// ExpectFind sets up an expectation for Find +func (m *MockDatabase) ExpectFind(result any, err error) *MockDatabase { + m.FindFunc = func(ctx context.Context, db string, collection string, filter any, opts ...any) (any, error) { + return result, err + } + return m +} + +// ExpectFindOne sets up an expectation for FindOne +func (m *MockDatabase) ExpectFindOne(result any, err error) *MockDatabase { + m.FindOneFunc = func(ctx context.Context, db string, collection string, filter any, opts ...any) (any, error) { + return result, err + } + return m +} + +// QueuePing adds a Ping response to the queue for sequential calls +func (m *MockDatabase) QueuePing(err error) *MockDatabase { + m.PingQueue = append(m.PingQueue, PingResponse{Err: err}) + return m +} + +// QueueFind adds a Find response to the queue for sequential calls +func (m *MockDatabase) QueueFind(result any, err error) *MockDatabase { + m.FindQueue = append(m.FindQueue, FindResponse{Result: result, Err: err}) + return m +} + +// QueueFindOne adds a FindOne response to the queue for sequential calls +func (m *MockDatabase) QueueFindOne(result any, err error) *MockDatabase { + m.FindOneQueue = append(m.FindOneQueue, FindOneResponse{Result: result, Err: err}) + return m +} diff --git a/pkg/database/mock_test.go b/pkg/database/mock_test.go new file mode 100644 index 0000000..3e51628 --- /dev/null +++ b/pkg/database/mock_test.go @@ -0,0 +1,377 @@ +package database + +import ( + "context" + "errors" + "fmt" + "testing" +) + +func TestMockDatabase(t *testing.T) { + t.Run("DefaultBehavior", func(t *testing.T) { + mock := NewMockDatabase() + + // Test Ping default (should succeed) + err := mock.Ping(context.Background()) + if err != nil { + t.Errorf("expected nil error, got %v", err) + } + + // Test Find default (should return empty slice) + result, err := mock.Find(context.Background(), "testdb", "users", map[string]any{"id": 1}) + if err != nil { + t.Errorf("expected nil error, got %v", err) + } + if result == nil { + t.Error("expected non-nil result") + } + + // Test FindOne default (should return error) + result, err = mock.FindOne(context.Background(), "testdb", "users", map[string]any{"id": 1}) + if err == nil { + t.Error("expected error, got nil") + } + if result != nil { + t.Error("expected nil result") + } + }) + + t.Run("ExpectPingError", func(t *testing.T) { + mock := NewMockDatabase() + expectedErr := errors.New("connection failed") + + mock.ExpectPing(expectedErr) + + err := mock.Ping(context.Background()) + if err != expectedErr { + t.Errorf("expected %v, got %v", expectedErr, err) + } + + // Verify call was tracked + if len(mock.PingCalls) != 1 { + t.Errorf("expected 1 ping call, got %d", len(mock.PingCalls)) + } + }) + + t.Run("ExpectFindWithResults", func(t *testing.T) { + mock := NewMockDatabase() + expectedData := []map[string]any{ + {"id": 1, "name": "Alice"}, + {"id": 2, "name": "Bob"}, + } + + mock.ExpectFind(expectedData, nil) + + result, err := mock.Find(context.Background(), "testdb", "users", map[string]any{}) + if err != nil { + t.Errorf("expected nil error, got %v", err) + } + + resultSlice, ok := result.([]map[string]any) + if !ok { + t.Fatal("expected result to be []map[string]any") + } + + if len(resultSlice) != 2 { + t.Errorf("expected 2 results, got %d", len(resultSlice)) + } + + // Verify call tracking + if len(mock.FindCalls) != 1 { + t.Errorf("expected 1 find call, got %d", len(mock.FindCalls)) + } + if mock.FindCalls[0].Db != "testdb" { + t.Errorf("expected db 'testdb', got '%s'", mock.FindCalls[0].Db) + } + if mock.FindCalls[0].Collection != "users" { + t.Errorf("expected collection 'users', got '%s'", mock.FindCalls[0].Collection) + } + }) + + t.Run("ExpectFindOneWithResult", func(t *testing.T) { + mock := NewMockDatabase() + expectedUser := map[string]any{ + "id": 1, + "name": "Alice", + "email": "alice@example.com", + } + + mock.ExpectFindOne(expectedUser, nil) + + result, err := mock.FindOne(context.Background(), "testdb", "users", map[string]any{"id": 1}) + if err != nil { + t.Errorf("expected nil error, got %v", err) + } + + user, ok := result.(map[string]any) + if !ok { + t.Fatal("expected result to be map[string]any") + } + + if user["name"] != "Alice" { + t.Errorf("expected name 'Alice', got '%v'", user["name"]) + } + + // Verify call tracking + if len(mock.FindOneCalls) != 1 { + t.Errorf("expected 1 findOne call, got %d", len(mock.FindOneCalls)) + } + }) + + t.Run("ExpectFindOneNotFound", func(t *testing.T) { + mock := NewMockDatabase() + expectedErr := errors.New("document not found") + + mock.ExpectFindOne(nil, expectedErr) + + result, err := mock.FindOne(context.Background(), "testdb", "users", map[string]any{"id": 999}) + if err != expectedErr { + t.Errorf("expected error '%v', got '%v'", expectedErr, err) + } + if result != nil { + t.Error("expected nil result") + } + }) + + t.Run("CustomFindFunction", func(t *testing.T) { + mock := NewMockDatabase() + + // Custom function that returns different results based on filter + mock.FindFunc = func(ctx context.Context, db string, collection string, filter any, opts ...any) (any, error) { + filterMap, ok := filter.(map[string]any) + if !ok { + return nil, errors.New("invalid filter") + } + + if status, ok := filterMap["status"]; ok && status == "active" { + return []map[string]any{ + {"id": 1, "status": "active"}, + {"id": 2, "status": "active"}, + }, nil + } + + return []map[string]any{}, nil + } + + // Test with active status + result, err := mock.Find(context.Background(), "testdb", "users", map[string]any{"status": "active"}) + if err != nil { + t.Errorf("expected nil error, got %v", err) + } + if len(result.([]map[string]any)) != 2 { + t.Errorf("expected 2 results for active users") + } + + // Test with inactive status + result, err = mock.Find(context.Background(), "testdb", "users", map[string]any{"status": "inactive"}) + if err != nil { + t.Errorf("expected nil error, got %v", err) + } + if len(result.([]map[string]any)) != 0 { + t.Errorf("expected 0 results for inactive users") + } + + // Verify both calls were tracked + if len(mock.FindCalls) != 2 { + t.Errorf("expected 2 find calls, got %d", len(mock.FindCalls)) + } + }) + + t.Run("ResetCallHistory", func(t *testing.T) { + mock := NewMockDatabase() + + // Make some calls + mock.Ping(context.Background()) + mock.Find(context.Background(), "testdb", "users", map[string]any{}) + mock.FindOne(context.Background(), "testdb", "users", map[string]any{"id": 1}) + + if len(mock.PingCalls) != 1 || len(mock.FindCalls) != 1 || len(mock.FindOneCalls) != 1 { + t.Error("expected calls to be tracked") + } + + // Reset + mock.Reset() + + if len(mock.PingCalls) != 0 || len(mock.FindCalls) != 0 || len(mock.FindOneCalls) != 0 { + t.Error("expected all call history to be cleared") + } + }) + + t.Run("UseWithDatabase", func(t *testing.T) { + mock := NewMockDatabase() + mock.ExpectPing(nil) + + opts := NewMongoOptions(). + SetUri("mongodb://localhost"). + SetTimeout(5000). + Build() + + // Inject the mock as the database client + db, err := New(opts, mock) + if err != nil { + t.Fatalf("failed to create database with mock: %v", err) + } + + // Use the database with the mock + err = db.Client.Ping(context.Background()) + if err != nil { + t.Errorf("expected nil error, got %v", err) + } + + // Verify the mock was called + if len(mock.PingCalls) != 1 { + t.Errorf("expected 1 ping call on mock, got %d", len(mock.PingCalls)) + } + }) +} + +func TestMockDatabaseSequentialCalls(t *testing.T) { + t.Run("QueueMultipleFinds", func(t *testing.T) { + mock := NewMockDatabase() + + // Queue multiple responses + users := []map[string]any{ + {"id": 1, "name": "Alice"}, + {"id": 2, "name": "Bob"}, + } + notifications := []map[string]any{ + {"id": 1, "message": "Hello"}, + {"id": 2, "message": "World"}, + } + settings := []map[string]any{ + {"key": "theme", "value": "dark"}, + } + + mock.QueueFind(users, nil). + QueueFind(notifications, nil). + QueueFind(settings, nil) + + // First call returns users + result1, err := mock.Find(context.Background(), "testdb", "users", map[string]any{}) + if err != nil { + t.Errorf("unexpected error on first call: %v", err) + } + usersResult := result1.([]map[string]any) + if len(usersResult) != 2 || usersResult[0]["name"] != "Alice" { + t.Error("first call should return users") + } + + // Second call returns notifications + result2, err := mock.Find(context.Background(), "testdb", "notifications", map[string]any{}) + if err != nil { + t.Errorf("unexpected error on second call: %v", err) + } + notificationsResult := result2.([]map[string]any) + if len(notificationsResult) != 2 || notificationsResult[0]["message"] != "Hello" { + t.Error("second call should return notifications") + } + + // Third call returns settings + result3, err := mock.Find(context.Background(), "testdb", "settings", map[string]any{}) + if err != nil { + t.Errorf("unexpected error on third call: %v", err) + } + settingsResult := result3.([]map[string]any) + if len(settingsResult) != 1 || settingsResult[0]["key"] != "theme" { + t.Error("third call should return settings") + } + + // Fourth call falls back to default behavior (empty slice) + result4, err := mock.Find(context.Background(), "testdb", "other", map[string]any{}) + if err != nil { + t.Errorf("unexpected error on fourth call: %v", err) + } + if len(result4.([]any)) != 0 { + t.Error("fourth call should return empty slice (default)") + } + + // Verify all calls were tracked + if len(mock.FindCalls) != 4 { + t.Errorf("expected 4 find calls, got %d", len(mock.FindCalls)) + } + }) + + t.Run("QueueWithErrors", func(t *testing.T) { + mock := NewMockDatabase() + + // Queue responses with errors + mock.QueueFind([]map[string]any{{"id": 1}}, nil). + QueueFind(nil, fmt.Errorf("connection timeout")). + QueueFind([]map[string]any{{"id": 2}}, nil) + + // First call succeeds + result1, err := mock.Find(context.Background(), "testdb", "users", map[string]any{}) + if err != nil { + t.Errorf("expected no error, got %v", err) + } + if len(result1.([]map[string]any)) != 1 { + t.Error("first call should return 1 result") + } + + // Second call returns error + _, err = mock.Find(context.Background(), "testdb", "users", map[string]any{}) + if err == nil || err.Error() != "connection timeout" { + t.Errorf("expected 'connection timeout' error, got %v", err) + } + + // Third call succeeds again + result3, err := mock.Find(context.Background(), "testdb", "users", map[string]any{}) + if err != nil { + t.Errorf("expected no error, got %v", err) + } + if len(result3.([]map[string]any)) != 1 { + t.Error("third call should return 1 result") + } + }) + + t.Run("QueueFindOne", func(t *testing.T) { + mock := NewMockDatabase() + + // Queue multiple FindOne responses + mock.QueueFindOne(map[string]any{"id": 1, "name": "Alice"}, nil). + QueueFindOne(map[string]any{"id": 2, "name": "Bob"}, nil). + QueueFindOne(nil, fmt.Errorf("not found")) + + // First call + result1, err := mock.FindOne(context.Background(), "testdb", "users", map[string]any{"id": 1}) + if err != nil { + t.Errorf("unexpected error: %v", err) + } + if result1.(map[string]any)["name"] != "Alice" { + t.Error("first call should return Alice") + } + + // Second call + result2, err := mock.FindOne(context.Background(), "testdb", "users", map[string]any{"id": 2}) + if err != nil { + t.Errorf("unexpected error: %v", err) + } + if result2.(map[string]any)["name"] != "Bob" { + t.Error("second call should return Bob") + } + + // Third call returns error + _, err = mock.FindOne(context.Background(), "testdb", "users", map[string]any{"id": 3}) + if err == nil || err.Error() != "not found" { + t.Errorf("expected 'not found' error, got %v", err) + } + }) + + t.Run("ResetClearsQueue", func(t *testing.T) { + mock := NewMockDatabase() + + // Queue responses + mock.QueueFind([]map[string]any{{"id": 1}}, nil). + QueueFindOne(map[string]any{"id": 1}, nil) + + // Reset should clear queues + mock.Reset() + + if len(mock.FindQueue) != 0 { + t.Error("FindQueue should be empty after Reset") + } + if len(mock.FindOneQueue) != 0 { + t.Error("FindOneQueue should be empty after Reset") + } + }) +} diff --git a/pkg/database/mongodb.go b/pkg/database/mongodb.go index a9cfd35..9129552 100644 --- a/pkg/database/mongodb.go +++ b/pkg/database/mongodb.go @@ -99,7 +99,7 @@ func (b *MongoOptionsBuilder) Build() *MongoOptions { // MongoClient wraps mongo.Client to implement DatabaseInterface type MongoClient struct { Client *mongo.Client - options *MongoOptions + Options *MongoOptions } // NewMongoClient creates a new MongoClient with the provided MongoDB settings @@ -123,7 +123,7 @@ func newMongoClientFromURI(ctx context.Context, options *MongoOptions) (Database client, err := mongo.Connect(ctx, opts) return &MongoClient{ Client: client, - options: options, + Options: options, }, err } @@ -164,13 +164,58 @@ func newMongoClientFromComponents(ctx context.Context, options *MongoOptions) (D client, err := mongo.Connect(ctx, clientOpts) return &MongoClient{ Client: client, - options: options, + Options: options, }, err } -func (m *MongoClient) Ping() error { - ctx, cancel := context.WithTimeout(context.Background(), time.Duration(m.options.Timeout)*time.Millisecond) - defer cancel() +func (m *MongoClient) Ping(ctx context.Context) error { err := m.Client.Ping(ctx, nil) return err } + +// Find executes a find query on the specified database and collection +func (m *MongoClient) Find(ctx context.Context, db string, collection string, filter any, opts ...any) (any, error) { + coll := m.Client.Database(db).Collection(collection) + + // Convert opts to mongo.FindOptions if provided + var findOpts []*moptions.FindOptions + for _, opt := range opts { + if fo, ok := opt.(*moptions.FindOptions); ok { + findOpts = append(findOpts, fo) + } + } + + cursor, err := coll.Find(ctx, filter, findOpts...) + if err != nil { + return nil, err + } + defer cursor.Close(ctx) + + var results []any + if err = cursor.All(ctx, &results); err != nil { + return nil, err + } + + return results, nil +} + +// FindOne executes a findOne query on the specified database and collection +func (m *MongoClient) FindOne(ctx context.Context, db string, collection string, filter any, opts ...any) (any, error) { + coll := m.Client.Database(db).Collection(collection) + + // Convert opts to mongo.FindOneOptions if provided + var findOneOpts []*moptions.FindOneOptions + for _, opt := range opts { + if fo, ok := opt.(*moptions.FindOneOptions); ok { + findOneOpts = append(findOneOpts, fo) + } + } + + var result any + err := coll.FindOne(ctx, filter, findOneOpts...).Decode(&result) + if err != nil { + return nil, err + } + + return result, nil +} diff --git a/pkg/database/mongodb_test.go b/pkg/database/mongodb_test.go index b95ae0c..3efb29c 100644 --- a/pkg/database/mongodb_test.go +++ b/pkg/database/mongodb_test.go @@ -1,20 +1,14 @@ package database import ( + "context" "os" "testing" -) - -// MockDatabaseInterface is a mock implementation of DatabaseInterface for testing -type MockDatabaseInterface struct { - PingCalled bool - PingError error -} + "time" -func (m *MockDatabaseInterface) Ping() error { - m.PingCalled = true - return m.PingError -} + "github.com/uug-ai/models/pkg/models" + "go.mongodb.org/mongo-driver/bson" +) // TestMongoOptionsValidation tests the validation of MongoDB options func TestMongoOptionsValidation(t *testing.T) { @@ -158,10 +152,7 @@ func TestMongoOptionsValidation(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { opts := tt.buildOpts() - mockClient := &MockDatabaseInterface{} - - _, err := New(opts, mockClient) - + _, err := New(opts) if tt.expectError && err == nil { t.Errorf("expected validation error but got nil") } @@ -247,7 +238,7 @@ func TestMongodbLiveIntegration(t *testing.T) { mongodbUri := os.Getenv("MONGODB_URI") return NewMongoOptions(). SetUri(mongodbUri). - SetTimeout(2000). + SetTimeout(5000). Build() }, expectError: false, @@ -280,7 +271,10 @@ func TestMongodbLiveIntegration(t *testing.T) { t.Fatalf("failed to create database instance: %v", err) } - err = db.Client.Ping() + ctx, cancel := context.WithTimeout(context.Background(), time.Duration(db.Options.Timeout)*time.Millisecond) + defer cancel() + + err = db.Client.Ping(ctx) if tt.expectError && err == nil { t.Errorf("expected ping error but got nil") } @@ -290,3 +284,57 @@ func TestMongodbLiveIntegration(t *testing.T) { }) } } + +func TestFindIntegration(t *testing.T) { + mongodbUri := os.Getenv("MONGODB_URI") + if mongodbUri == "" { + t.Skip("MONGODB_URI not set, skipping integration test") + } + + opts := NewMongoOptions(). + SetUri(mongodbUri). + SetTimeout(5000). + Build() + + db, err := New(opts) + if err != nil { + t.Fatalf("failed to create database instance: %v", err) + } + + ctx, cancel := context.WithTimeout(context.Background(), time.Duration(db.Options.Timeout)*time.Millisecond) + defer cancel() + + // Test Find with username filter + filter := map[string]any{"username": "cedricve"} + results, err := db.Client.Find(ctx, "Kerberos", "users", filter) + if err != nil { + t.Fatalf("Find failed: %v", err) + } + + // Validate results + resultSlice, ok := results.([]any) + if !ok { + t.Fatalf("expected results to be []any, got %T", results) + } + + if len(resultSlice) != 1 { + t.Fatalf("expected exactly 1 result for username 'cedricve', got %d", len(resultSlice)) + } + + // Marshal the result to User struct + resultBytes, err := bson.Marshal(resultSlice[0]) + if err != nil { + t.Fatalf("failed to marshal result: %v", err) + } + + var user models.User + err = bson.Unmarshal(resultBytes, &user) + if err != nil { + t.Fatalf("failed to unmarshal to User struct: %v", err) + } + + // Validate user fields + if user.Username != "cedricve" { + t.Errorf("expected username 'cedricve', got '%s'", user.Username) + } +}