diff --git a/.github/workflows/pr-description.yml b/.github/workflows/pr-description.yml index a453f8b..b349ec5 100644 --- a/.github/workflows/pr-description.yml +++ b/.github/workflows/pr-description.yml @@ -20,4 +20,5 @@ jobs: azure_openai_api_key: ${{ secrets.AZURE_OPENAI_API_KEY }} azure_openai_endpoint: ${{ secrets.AZURE_OPENAI_ENDPOINT }} azure_openai_version: ${{ secrets.AZURE_OPENAI_VERSION }} + openai_model: ${{ secrets.OPENAI_MODEL }} overwrite_description: true diff --git a/.github/workflows/security-scan.yml b/.github/workflows/security-scan.yml index fb1ff26..bba1ecf 100644 --- a/.github/workflows/security-scan.yml +++ b/.github/workflows/security-scan.yml @@ -15,6 +15,7 @@ env: jobs: security-scan: runs-on: ubuntu-latest + steps: # Checks-out your repository under $GITHUB_WORKSPACE, so your job can access it - name: Checkout repository @@ -55,4 +56,4 @@ jobs: #format: "github" #github-pat: ${{ secrets.TOKEN }} env: - TRIVY_SKIP_DB_UPDATE: true + TRIVY_SKIP_DB_UPDATE: true \ No newline at end of file diff --git a/pkg/database/database.go b/pkg/database/database.go index 784bd51..21980ae 100644 --- a/pkg/database/database.go +++ b/pkg/database/database.go @@ -35,6 +35,7 @@ type DatabaseInterface interface { Find(ctx context.Context, db string, collection string, filter any, opts ...any) FindResultInterface FindOne(ctx context.Context, db string, collection string, filter any, opts ...any) SingleResultInterface UpdateOne(ctx context.Context, db string, collection string, filter any, update any, opts ...any) (UpdateResultInterface, error) + Count(ctx context.Context, db string, collection string, filter any, opts ...any) (int64, error) Disconnect(ctx context.Context) error InsertOne(ctx context.Context, db string, collection string, document any, opts ...any) (any, error) InsertMany(ctx context.Context, db string, collection string, documents []any, opts ...any) (any, error) diff --git a/pkg/database/mock.go b/pkg/database/mock.go index d40caf6..191b31a 100644 --- a/pkg/database/mock.go +++ b/pkg/database/mock.go @@ -23,17 +23,22 @@ type MockDatabase struct { // UpdateOneFunc allows customizing UpdateOne behavior UpdateOneFunc func(ctx context.Context, db string, collection string, filter any, update any, opts ...any) (UpdateResultInterface, error) + // CountFunc allows customizing Count behavior + CountFunc func(ctx context.Context, db string, collection string, filter any, opts ...any) (int64, error) + // Sequential response queues for multiple calls PingQueue []PingResponse FindQueue []FindResponse FindOneQueue []FindOneResponse UpdateOneQueue []UpdateOneResponse + CountQueue []CountResponse // Call tracking PingCalls []PingCall FindCalls []FindCall FindOneCalls []FindOneCall UpdateOneCalls []UpdateOneCall + CountCalls []CountCall } // MockSingleResult implements SingleResultInterface for testing @@ -172,6 +177,21 @@ type UpdateOneCall struct { Opts []any } +// CountResponse represents a queued response for Count +type CountResponse struct { + Count int64 + Err error +} + +// CountCall records a call to Count +type CountCall struct { + Ctx context.Context + Db string + Collection string + Filter any + Opts []any +} + // NewMockDatabase creates a new MockDatabase with sensible defaults func NewMockDatabase() *MockDatabase { return &MockDatabase{ @@ -187,14 +207,19 @@ func NewMockDatabase() *MockDatabase { UpdateOneFunc: func(ctx context.Context, db string, collection string, filter any, update any, opts ...any) (UpdateResultInterface, error) { return &MockUpdateResult{matchedCount: 1, modifiedCount: 1}, nil }, + CountFunc: func(ctx context.Context, db string, collection string, filter any, opts ...any) (int64, error) { + return 0, nil + }, PingCalls: []PingCall{}, FindCalls: []FindCall{}, FindOneCalls: []FindOneCall{}, UpdateOneCalls: []UpdateOneCall{}, + CountCalls: []CountCall{}, PingQueue: []PingResponse{}, FindQueue: []FindResponse{}, FindOneQueue: []FindOneResponse{}, UpdateOneQueue: []UpdateOneResponse{}, + CountQueue: []CountResponse{}, } } @@ -429,6 +454,30 @@ func (m *MockDatabase) UpdateOne(ctx context.Context, db string, collection stri return &MockUpdateResult{matchedCount: 1, modifiedCount: 1}, nil } +// Count implements DatabaseInterface +func (m *MockDatabase) Count(ctx context.Context, db string, collection string, filter any, opts ...any) (int64, error) { + m.CountCalls = append(m.CountCalls, CountCall{ + Ctx: ctx, + Db: db, + Collection: collection, + Filter: filter, + Opts: opts, + }) + + // Check if there's a queued response + if len(m.CountQueue) > 0 { + response := m.CountQueue[0] + m.CountQueue = m.CountQueue[1:] + return response.Count, response.Err + } + + // Fall back to CountFunc + if m.CountFunc != nil { + return m.CountFunc(ctx, db, collection, filter, opts...) + } + return 0, nil +} + // InsertOne implements DatabaseInterface func (m *MockDatabase) InsertOne(ctx context.Context, db string, collection string, document any, opts ...any) (any, error) { return nil, fmt.Errorf("InsertOne not implemented in MockDatabase") @@ -445,10 +494,12 @@ func (m *MockDatabase) Reset() { m.FindCalls = []FindCall{} m.FindOneCalls = []FindOneCall{} m.UpdateOneCalls = []UpdateOneCall{} + m.CountCalls = []CountCall{} m.PingQueue = []PingResponse{} m.FindQueue = []FindResponse{} m.FindOneQueue = []FindOneResponse{} m.UpdateOneQueue = []UpdateOneResponse{} + m.CountQueue = []CountResponse{} } // ExpectPing sets up an expectation for Ping @@ -520,3 +571,17 @@ func (m *MockDatabase) QueueUpdateOne(matchedCount, modifiedCount, upsertedCount }) return m } + +// ExpectCount sets up an expectation for Count +func (m *MockDatabase) ExpectCount(count int64, err error) *MockDatabase { + m.CountFunc = func(ctx context.Context, db string, collection string, filter any, opts ...any) (int64, error) { + return count, err + } + return m +} + +// QueueCount adds a Count response to the queue for sequential calls +func (m *MockDatabase) QueueCount(count int64, err error) *MockDatabase { + m.CountQueue = append(m.CountQueue, CountResponse{Count: count, Err: err}) + return m +} diff --git a/pkg/database/mock_test.go b/pkg/database/mock_test.go index 9adfe57..bd74272 100644 --- a/pkg/database/mock_test.go +++ b/pkg/database/mock_test.go @@ -376,5 +376,160 @@ func TestMockDatabaseSequentialCalls(t *testing.T) { if len(mock.FindOneQueue) != 0 { t.Error("FindOneQueue should be empty after Reset") } + if len(mock.CountQueue) != 0 { + t.Error("CountQueue should be empty after Reset") + } + }) +} + +func TestMockDatabaseCount(t *testing.T) { + t.Run("DefaultBehavior", func(t *testing.T) { + mock := NewMockDatabase() + + count, err := mock.Count(context.Background(), "testdb", "users", map[string]any{}) + if err != nil { + t.Errorf("expected nil error, got %v", err) + } + if count != 0 { + t.Errorf("expected count 0, got %d", count) + } + + // Verify call tracking + if len(mock.CountCalls) != 1 { + t.Errorf("expected 1 count call, got %d", len(mock.CountCalls)) + } + if mock.CountCalls[0].Db != "testdb" { + t.Errorf("expected db 'testdb', got '%s'", mock.CountCalls[0].Db) + } + if mock.CountCalls[0].Collection != "users" { + t.Errorf("expected collection 'users', got '%s'", mock.CountCalls[0].Collection) + } + }) + + t.Run("ExpectCount", func(t *testing.T) { + mock := NewMockDatabase() + mock.ExpectCount(42, nil) + + count, err := mock.Count(context.Background(), "testdb", "users", map[string]any{"status": "active"}) + if err != nil { + t.Errorf("expected nil error, got %v", err) + } + if count != 42 { + t.Errorf("expected count 42, got %d", count) + } + }) + + t.Run("ExpectCountWithError", func(t *testing.T) { + mock := NewMockDatabase() + expectedErr := fmt.Errorf("connection failed") + mock.ExpectCount(0, expectedErr) + + count, err := mock.Count(context.Background(), "testdb", "users", map[string]any{}) + if err != expectedErr { + t.Errorf("expected error '%v', got '%v'", expectedErr, err) + } + if count != 0 { + t.Errorf("expected count 0, got %d", count) + } + }) + + t.Run("QueueMultipleCounts", func(t *testing.T) { + mock := NewMockDatabase() + + mock.QueueCount(10, nil). + QueueCount(0, fmt.Errorf("timeout")). + QueueCount(25, nil) + + // First call + count, err := mock.Count(context.Background(), "testdb", "users", map[string]any{}) + if err != nil { + t.Errorf("expected nil error, got %v", err) + } + if count != 10 { + t.Errorf("expected count 10, got %d", count) + } + + // Second call returns error + count, err = mock.Count(context.Background(), "testdb", "users", map[string]any{}) + if err == nil || err.Error() != "timeout" { + t.Errorf("expected 'timeout' error, got %v", err) + } + if count != 0 { + t.Errorf("expected count 0, got %d", count) + } + + // Third call succeeds + count, err = mock.Count(context.Background(), "testdb", "orders", map[string]any{}) + if err != nil { + t.Errorf("expected nil error, got %v", err) + } + if count != 25 { + t.Errorf("expected count 25, got %d", count) + } + + // Fourth call falls back to default + count, err = mock.Count(context.Background(), "testdb", "other", map[string]any{}) + if err != nil { + t.Errorf("expected nil error, got %v", err) + } + if count != 0 { + t.Errorf("expected count 0 (default), got %d", count) + } + + // Verify all calls tracked + if len(mock.CountCalls) != 4 { + t.Errorf("expected 4 count calls, got %d", len(mock.CountCalls)) + } + }) + + t.Run("CustomCountFunc", func(t *testing.T) { + mock := NewMockDatabase() + + mock.CountFunc = func(ctx context.Context, db string, collection string, filter any, opts ...any) (int64, error) { + if collection == "users" { + return 100, nil + } + return 0, fmt.Errorf("unknown collection") + } + + count, err := mock.Count(context.Background(), "testdb", "users", map[string]any{}) + if err != nil { + t.Errorf("expected nil error, got %v", err) + } + if count != 100 { + t.Errorf("expected count 100, got %d", count) + } + + count, err = mock.Count(context.Background(), "testdb", "unknown", map[string]any{}) + if err == nil { + t.Error("expected error for unknown collection") + } + if count != 0 { + t.Errorf("expected count 0, got %d", count) + } + + if len(mock.CountCalls) != 2 { + t.Errorf("expected 2 count calls, got %d", len(mock.CountCalls)) + } + }) + + t.Run("ResetClearsCountState", func(t *testing.T) { + mock := NewMockDatabase() + + mock.QueueCount(5, nil) + mock.Count(context.Background(), "testdb", "users", map[string]any{}) + + if len(mock.CountCalls) != 1 { + t.Error("expected count call to be tracked") + } + + mock.Reset() + + if len(mock.CountCalls) != 0 { + t.Error("expected count calls to be cleared after reset") + } + if len(mock.CountQueue) != 0 { + t.Error("expected count queue to be cleared after reset") + } }) } diff --git a/pkg/database/mongodb.go b/pkg/database/mongodb.go index bd50651..768ccc5 100644 --- a/pkg/database/mongodb.go +++ b/pkg/database/mongodb.go @@ -368,6 +368,21 @@ func (ur *UpdateResult) UpsertedID() any { return ur.result.UpsertedID } +// Count returns the number of documents in the specified database and collection matching the filter. +// Supports *moptions.CountOptions in opts. +func (m *MongoClient) Count(ctx context.Context, db string, collection string, filter any, opts ...any) (int64, error) { + coll := m.Client.Database(db).Collection(collection) + + var countOpts []*moptions.CountOptions + for _, opt := range opts { + if co, ok := opt.(*moptions.CountOptions); ok { + countOpts = append(countOpts, co) + } + } + + return coll.CountDocuments(ctx, filter, countOpts...) +} + // UpdateOne executes an update query on a single document in the specified database and collection. // Returns an UpdateResult that provides access to matched, modified, and upserted counts. // Supports *moptions.UpdateOptions in opts.