diff --git a/src/pkg/parser/parser_pool.go b/src/pkg/parser/parser_pool.go
new file mode 100644
index 0000000..1db610b
--- /dev/null
+++ b/src/pkg/parser/parser_pool.go
@@ -0,0 +1,188 @@
+package parser
+
+import (
+	"context"
+	"fmt"
+	"sync"
+
+	sitter "github.com/smacker/go-tree-sitter"
+
+	"github.com/specvital/core/domain"
+)
+
+// Pools of reusable tree-sitter parsers; getParserPool selects the pool for
+// a language.
+var (
+	goParserPool sync.Pool
+	jsParserPool sync.Pool
+	tsParserPool sync.Pool
+)
+
+// getParserPool returns the pool that holds parsers for lang.
+//
+// NOTE(review): every language other than Go and JavaScript falls through to
+// tsParserPool. That looks safe only while TypeScript is the sole other
+// language passed in; confirm against domain.Language before adding more.
+func getParserPool(lang domain.Language) *sync.Pool {
+	switch lang {
+	case domain.LanguageGo:
+		return &goParserPool
+	case domain.LanguageJavaScript:
+		return &jsParserPool
+	default:
+		return &tsParserPool
+	}
+}
+
+// getPooledParser returns a parser for lang, taking one from the pool when
+// available and otherwise creating a new parser with its language set.
+func getPooledParser(lang domain.Language) *sitter.Parser {
+	if p, ok := getParserPool(lang).Get().(*sitter.Parser); ok {
+		return p
+	}
+
+	initLanguages()
+	parser := sitter.NewParser()
+	parser.SetLanguage(getSitterLanguage(lang))
+	return parser
+}
+
+// putPooledParser hands parser back to the pool for lang; nil is ignored.
+func putPooledParser(lang domain.Language, parser *sitter.Parser) {
+	if parser == nil {
+		return
+	}
+	getParserPool(lang).Put(parser)
+}
+
+// ParseWithPool parses source using a pooled parser.
+// Caller must close the returned tree.
+//
+// When parsing fails (for example because ctx was cancelled) the parser is
+// reset before it returns to the pool, so the next caller parses from
+// scratch instead of resuming the abandoned parse.
+func ParseWithPool(ctx context.Context, lang domain.Language, source []byte) (*sitter.Tree, error) {
+	parser := getPooledParser(lang)
+	defer putPooledParser(lang, parser)
+
+	tree, err := parser.ParseCtx(ctx, nil, source)
+	if err != nil {
+		// Runs before the deferred put: the pool must never hold a parser
+		// carrying state from an interrupted parse.
+		parser.Reset()
+		return nil, fmt.Errorf("parse failed: %w", err)
+	}
+
+	return tree, nil
+}
+
+// queryCacheKey identifies a compiled query by language and query text.
+type queryCacheKey struct {
+	lang     domain.Language
+	queryStr string
+}
+
+// cachedQuery is a queryCache entry: once guards the compilation whose
+// outcome is stored in query and err.
+type cachedQuery struct {
+	once  sync.Once
+	query *sitter.Query
+	err   error
+}
+
+// queryCache maps queryCacheKey values to *cachedQuery entries.
+var queryCache sync.Map
+
+// getCachedQuery returns a compiled query. The returned query must NOT be closed.
+func getCachedQuery(lang domain.Language, queryStr string) (*sitter.Query, error) {
+	key := queryCacheKey{lang: lang, queryStr: queryStr}
+
+	// Load first so cache hits do not allocate a throwaway entry; on a miss,
+	// LoadOrStore settles races between callers inserting the same key.
+	val, ok := queryCache.Load(key)
+	if !ok {
+		val, _ = queryCache.LoadOrStore(key, &cachedQuery{})
+	}
+
+	cached, ok := val.(*cachedQuery)
+	if !ok {
+		return nil, fmt.Errorf("invalid cache entry type")
+	}
+
+	// Every caller passes the real compile function. An entry is visible in the
+	// map before it is compiled, so a caller that arrives in that window must
+	// either compile it or block until the compiling goroutine finishes.
+	// Calling once.Do with a no-op there would mark the entry done while query
+	// is still nil and hand (nil, nil) to every caller of this key.
+	cached.once.Do(func() {
+		initLanguages()
+		cached.query, cached.err = sitter.NewQuery([]byte(queryStr), getSitterLanguage(lang))
+	})
+
+	return cached.query, cached.err
+}
+
+// QueryWithCache executes a query with cached compilation.
+//
+// It runs queryStr against root and returns one QueryResult per match. Each
+// result maps capture names to their nodes, and Node is set from the first
+// capture of the match. source is not used by this function.
+//
+// An error is returned when no compiled query is available for lang.
+func QueryWithCache(root *sitter.Node, source []byte, lang domain.Language, queryStr string) ([]QueryResult, error) {
+	query, err := getCachedQuery(lang, queryStr)
+	if err != nil {
+		return nil, fmt.Errorf("invalid query: %w", err)
+	}
+	if query == nil {
+		// Never hand a nil query to the cursor.
+		return nil, fmt.Errorf("invalid query: no compiled query available")
+	}
+
+	cursor := sitter.NewQueryCursor()
+	defer cursor.Close()
+
+	cursor.Exec(query, root)
+
+	var results []QueryResult
+	for {
+		match, ok := cursor.NextMatch()
+		if !ok {
+			break
+		}
+
+		result := QueryResult{
+			Captures: make(map[string]*sitter.Node, len(match.Captures)),
+		}
+		for _, capture := range match.Captures {
+			name := query.CaptureNameForId(capture.Index)
+			result.Captures[name] = capture.Node
+			if result.Node == nil {
+				result.Node = capture.Node
+			}
+		}
+		results = append(results, result)
+	}
+
+	return results, nil
+}
+
+// ClearQueryCache removes all cached queries. Only for testing.
+func ClearQueryCache() {
+	var toClose []*sitter.Query
+	queryCache.Range(func(key, value any) bool {
+		queryCache.Delete(key)
+		if cached, ok := value.(*cachedQuery); ok {
+			// Wait for any in-flight compile; if none ran, record an error instead of leaving (nil, nil).
+			cached.once.Do(func() { cached.err = fmt.Errorf("query cache cleared") })
+			if cached.query != nil && cached.err == nil {
+				toClose = append(toClose, cached.query)
+			}
+		}
+		return true
+	})
+
+	for _, q := range toClose {
+		q.Close()
+	}
+}
diff --git a/src/pkg/parser/parser_pool_test.go b/src/pkg/parser/parser_pool_test.go
new file mode 100644
index 0000000..7fd7707
--- /dev/null
+++ b/src/pkg/parser/parser_pool_test.go
@@ -0,0 +1,452 @@
+package parser
+
+import (
+	"context"
+	"sync"
+	"testing"
+
+	"github.com/specvital/core/domain"
+)
+
+// TestParserPool_RaceFree parses the same Go source from many goroutines through ParseWithPool.
+func TestParserPool_RaceFree(t *testing.T) {
+	const goroutines = 100
+	const iterations = 10
+
+	source := []byte(`
+package main
+
+import "testing"
+
+func TestExample(t *testing.T) {
+	t.Run("subtest", func(t *testing.T) {
+		// test code
+	})
+}
+`)
+
+	var wg sync.WaitGroup
+	ctx := context.Background()
+
+	// Test concurrent access to parser pool
+	for i := 0; i < goroutines; i++ {
+		wg.Add(1)
+		go func(id int) {
+			defer wg.Done()
+
+			for j := 0; j < iterations; j++ {
+				tree, err := ParseWithPool(ctx, domain.LanguageGo, source)
+				if err != nil {
+					t.Errorf("goroutine %d iteration %d: parse failed: %v", id, j, err)
+					return
+				}
+
+				if tree == nil {
+					t.Errorf("goroutine %d iteration %d: tree is nil", id, j)
+					return
+				}
+
+				tree.Close()
+			}
+		}(i)
+	}
+
+	wg.Wait()
+}
+
+// TestParserPool_MultipleLanguages parses Go, JavaScript and TypeScript sources concurrently.
+func TestParserPool_MultipleLanguages(t *testing.T) {
+	const goroutines = 50
+
+	sources := map[domain.Language][]byte{
+		domain.LanguageGo: []byte(`
+package main
+func TestExample(t *testing.T) {}
+`),
+		domain.LanguageJavaScript: []byte(`
+describe('test', () => {
+  it('works', () => {});
+});
+`),
+		domain.LanguageTypeScript: []byte(`
+describe('test', () => {
+  it('works', () => {});
+});
+`),
+	}
+
+	var wg sync.WaitGroup
+	ctx := context.Background()
+
+	// Test that different language pools work independently
+	for lang, source := range sources {
+		for i := 0; i < goroutines; i++ {
+			wg.Add(1)
+			go func(l domain.Language, src []byte) {
+				defer wg.Done()
+
+				tree, err := ParseWithPool(ctx, l, src)
+				if err != nil {
+					t.Errorf("language %v: parse failed: %v", l, err)
+					return
+				}
+				if tree == nil {
+					t.Errorf("language %v: tree is nil", l)
+					return
+				}
+				defer tree.Close()
+			}(lang, source)
+		}
+	}
+
+	wg.Wait()
+}
+
+// TestQueryCache_RaceFree runs the same query through QueryWithCache from many goroutines.
+func TestQueryCache_RaceFree(t *testing.T) {
+	const goroutines = 100
+
+	source := []byte(`
+package main
+
+import "testing"
+
+func TestExample(t *testing.T) {
+	t.Run("subtest", func(t *testing.T) {
+		// test code
+	})
+}
+`)
+
+	queryStr := `
+(function_declaration
+  name: (identifier) @name
+  parameters: (parameter_list
+    (parameter_declaration
+      type: (pointer_type
+        (qualified_type) @param_type))))
+`
+
+	ctx := context.Background()
+
+	var wg sync.WaitGroup
+
+	// Test concurrent query compilation and caching
+	for i := 0; i < goroutines; i++ {
+		wg.Add(1)
+		go func(id int) {
+			defer wg.Done()
+
+			tree, err := ParseWithPool(ctx, domain.LanguageGo, source)
+			if err != nil {
+				t.Errorf("goroutine %d: parse failed: %v", id, err)
+				return
+			}
+			defer tree.Close()
+
+			results, err := QueryWithCache(tree.RootNode(), source, domain.LanguageGo, queryStr)
+			if err != nil {
+				t.Errorf("goroutine %d: query failed: %v", id, err)
+				return
+			}
+
+			if len(results) == 0 {
+				t.Errorf("goroutine %d: expected query results", id)
+			}
+		}(i)
+	}
+
+	wg.Wait()
+}
+
+// TestQueryCache_SameQueryReused runs one query twice and compares the result counts.
+func TestQueryCache_SameQueryReused(t *testing.T) {
+	defer ClearQueryCache()
+
+	source := []byte(`package main`)
+	queryStr := `(package_clause)`
+
+	ctx := context.Background()
+	lang := domain.LanguageGo
+
+	// First call - should compile query
+	tree1, err := ParseWithPool(ctx, lang, source)
+	if err != nil {
+		t.Fatalf("parse 1 failed: %v", err)
+	}
+	defer tree1.Close()
+
+	results1, err := QueryWithCache(tree1.RootNode(), source, lang, queryStr)
+	if err != nil {
+		t.Fatalf("query 1 failed: %v", err)
+	}
+
+	// Second call - should use cached query
+	tree2, err := ParseWithPool(ctx, lang, source)
+	if err != nil {
+		t.Fatalf("parse 2 failed: %v", err)
+	}
+	defer tree2.Close()
+
+	results2, err := QueryWithCache(tree2.RootNode(), source, lang, queryStr)
+	if err != nil {
+		t.Fatalf("query 2 failed: %v", err)
+	}
+
+	// Results should be identical (cached query works)
+	if len(results1) != len(results2) {
+		t.Errorf("result count mismatch: got %d and %d", len(results1), len(results2))
+	}
+}
+
+// TestGetPooledParser_ReturnsValidParser checks get/put round trips for each language.
+func TestGetPooledParser_ReturnsValidParser(t *testing.T) {
+	tests := []struct {
+		name string
+		lang domain.Language
+	}{
+		{"Go", domain.LanguageGo},
+		{"JavaScript", domain.LanguageJavaScript},
+		{"TypeScript", domain.LanguageTypeScript},
+	}
+
+	for _, tt := range tests {
+		t.Run(tt.name, func(t *testing.T) {
+			parser := getPooledParser(tt.lang)
+			if parser == nil {
+				t.Fatal("parser is nil")
+			}
+
+			// Verify parser can be used
+			ctx := context.Background()
+			source := []byte("package main")
+			if tt.lang == domain.LanguageJavaScript || tt.lang == domain.LanguageTypeScript {
+				source = []byte("console.log('test');")
+			}
+
+			tree, err := parser.ParseCtx(ctx, nil, source)
+			if err != nil {
+				t.Fatalf("parse failed: %v", err)
+			}
+			defer tree.Close()
+
+			// Return to pool
+			putPooledParser(tt.lang, parser)
+
+			// Get again - should get same or different parser (both valid)
+			parser2 := getPooledParser(tt.lang)
+			if parser2 == nil {
+				t.Fatal("second parser is nil")
+			}
+			putPooledParser(tt.lang, parser2)
+		})
+	}
+}
+
+// TestClearQueryCache checks that a query can be obtained again after ClearQueryCache.
+func TestClearQueryCache(t *testing.T) {
+	queryStr := `(package_clause)`
+	lang := domain.LanguageGo
+
+	// Add query to cache
+	_, err := getCachedQuery(lang, queryStr)
+	if err != nil {
+		t.Fatalf("failed to cache query: %v", err)
+	}
+
+	// Clear cache
+	ClearQueryCache()
+
+	// Query should be recompiled (no error expected, just testing it works)
+	_, err = getCachedQuery(lang, queryStr)
+	if err != nil {
+		t.Fatalf("failed to recompile query after clear: %v", err)
+	}
+}
+
+// BenchmarkParser_Direct is the baseline for the pooled benchmarks: it parses through NewTSParser.
+func BenchmarkParser_Direct(b *testing.B) {
+	source := []byte(`
+package main
+
+import "testing"
+
+func TestExample(t *testing.T) {
+	t.Run("subtest", func(t *testing.T) {
+		// test code
+	})
+}
+`)
+	ctx := context.Background()
+
+	b.ResetTimer()
+	for i := 0; i < b.N; i++ {
+		p := NewTSParser(domain.LanguageGo)
+		tree, err := p.Parse(ctx, source)
+		if err != nil {
+			b.Fatal(err)
+		}
+		tree.Close()
+	}
+}
+
+// BenchmarkParser_Pooled measures sequential parsing through ParseWithPool.
+func BenchmarkParser_Pooled(b *testing.B) {
+	source := []byte(`
+package main
+
+import "testing"
+
+func TestExample(t *testing.T) {
+	t.Run("subtest", func(t *testing.T) {
+		// test code
+	})
+}
+`)
+	ctx := context.Background()
+
+	b.ResetTimer()
+	for i := 0; i < b.N; i++ {
+		tree, err := ParseWithPool(ctx, domain.LanguageGo, source)
+		if err != nil {
+			b.Fatal(err)
+		}
+		tree.Close()
+	}
+}
+
+// BenchmarkParser_PooledParallel measures ParseWithPool under b.RunParallel.
+func BenchmarkParser_PooledParallel(b *testing.B) {
+	source := []byte(`
+package main
+
+import "testing"
+
+func TestExample(t *testing.T) {
+	t.Run("subtest", func(t *testing.T) {
+		// test code
+	})
+}
+`)
+	ctx := context.Background()
+
+	b.RunParallel(func(pb *testing.PB) {
+		for pb.Next() {
+			tree, err := ParseWithPool(ctx, domain.LanguageGo, source)
+			if err != nil {
+				// b.Fatal may only be called from the benchmark's own goroutine.
+				b.Error(err)
+				return
+			}
+			tree.Close()
+		}
+	})
+}
+
+// BenchmarkQuery_Direct is the baseline for the cached query benchmarks: it calls Query.
+func BenchmarkQuery_Direct(b *testing.B) {
+	source := []byte(`
+package main
+
+import "testing"
+
+func TestExample(t *testing.T) {
+	t.Run("subtest", func(t *testing.T) {})
+}
+`)
+	queryStr := `
+(function_declaration
+  name: (identifier) @name
+  parameters: (parameter_list))
+`
+	ctx := context.Background()
+	lang := domain.LanguageGo
+
+	tree, err := ParseWithPool(ctx, lang, source)
+	if err != nil {
+		b.Fatal(err)
+	}
+	defer tree.Close()
+	root := tree.RootNode()
+
+	b.ResetTimer()
+	for i := 0; i < b.N; i++ {
+		_, err := Query(root, source, lang, queryStr)
+		if err != nil {
+			b.Fatal(err)
+		}
+	}
+}
+
+// BenchmarkQuery_Cached measures sequential queries through QueryWithCache.
+func BenchmarkQuery_Cached(b *testing.B) {
+	source := []byte(`
+package main
+
+import "testing"
+
+func TestExample(t *testing.T) {
+	t.Run("subtest", func(t *testing.T) {})
+}
+`)
+	queryStr := `
+(function_declaration
+  name: (identifier) @name
+  parameters: (parameter_list))
+`
+	ctx := context.Background()
+	lang := domain.LanguageGo
+
+	tree, err := ParseWithPool(ctx, lang, source)
+	if err != nil {
+		b.Fatal(err)
+	}
+	defer tree.Close()
+	root := tree.RootNode()
+
+	b.ResetTimer()
+	for i := 0; i < b.N; i++ {
+		_, err := QueryWithCache(root, source, lang, queryStr)
+		if err != nil {
+			b.Fatal(err)
+		}
+	}
+}
+
+// BenchmarkQuery_CachedParallel measures QueryWithCache under b.RunParallel.
+func BenchmarkQuery_CachedParallel(b *testing.B) {
+	source := []byte(`
+package main
+
+import "testing"
+
+func TestExample(t *testing.T) {
+	t.Run("subtest", func(t *testing.T) {})
+}
+`)
+	queryStr := `
+(function_declaration
+  name: (identifier) @name
+  parameters: (parameter_list))
+`
+	ctx := context.Background()
+	lang := domain.LanguageGo
+
+	tree, err := ParseWithPool(ctx, lang, source)
+	if err != nil {
+		b.Fatal(err)
+	}
+	defer tree.Close()
+	root := tree.RootNode()
+
+	b.RunParallel(func(pb *testing.PB) {
+		for pb.Next() {
+			_, err := QueryWithCache(root, source, lang, queryStr)
+			if err != nil {
+				// b.Fatal may only be called from the benchmark's own goroutine.
+				b.Error(err)
+				return
+			}
+		}
+	})
+}
diff --git a/src/pkg/parser/strategies/gotesting/gotesting.go b/src/pkg/parser/strategies/gotesting/gotesting.go
index 3e4e89b..662800d 100644
--- a/src/pkg/parser/strategies/gotesting/gotesting.go
+++ b/src/pkg/parser/strategies/gotesting/gotesting.go
@@ -62,9 +62,7 @@ func (s *Strategy) CanHandle(filename string, _ []byte) bool {
 }
 
 func (s *Strategy) Parse(ctx context.Context, source []byte, filename string) (*domain.TestFile, error) {
-	p := parser.NewTSParser(domain.LanguageGo)
-
-	tree, err := p.Parse(ctx, source)
+	tree, err := parser.ParseWithPool(ctx, domain.LanguageGo, source)
 	if err != nil {
 		return nil, fmt.Errorf("go-testing parser: failed to parse %s: %w", filename, err)
 	}
diff --git a/src/pkg/parser/strategies/jest/jest_test.go b/src/pkg/parser/strategies/jest/jest_test.go
index 4b1a622..032c975 100644
--- a/src/pkg/parser/strategies/jest/jest_test.go
+++ b/src/pkg/parser/strategies/jest/jest_test.go
@@ -314,11 +314,11 @@ func TestStrategy_Parse_Each(t *testing.T) {
 	s := &Strategy{}
 
 	tests := []struct {
-		name string
-		source string
-		wantCount int
-		wantFirst string
-		isSuite bool
+		name      string
+		source    string
+		wantCount int
+		wantFirst string
+		isSuite   bool
 	}{
 		{
 			name: "should parse describe.each",
diff --git a/src/pkg/parser/strategies/playwright/playwright.go b/src/pkg/parser/strategies/playwright/playwright.go
index db22eb5..577701d 100644
--- a/src/pkg/parser/strategies/playwright/playwright.go
+++ b/src/pkg/parser/strategies/playwright/playwright.go
@@ -69,9 +69,8 @@ func (s *Strategy) CanHandle(filename string, content []byte) bool {
 
 func (s *Strategy) Parse(ctx context.Context, source []byte, filename string) (*domain.TestFile, error) {
 	lang := jstest.DetectLanguage(filename)
-	p := parser.NewTSParser(lang)
 
-	tree, err := p.Parse(ctx, source)
+	tree, err := parser.ParseWithPool(ctx, lang, source)
 	if err != nil {
 		return nil, fmt.Errorf("playwright parser: failed to parse %s: %w", filename, err)
 	}
diff --git a/src/pkg/parser/strategies/shared/jstest/parser.go b/src/pkg/parser/strategies/shared/jstest/parser.go
index db29747..59fd9b3 100644
--- a/src/pkg/parser/strategies/shared/jstest/parser.go
+++ b/src/pkg/parser/strategies/shared/jstest/parser.go
@@ -210,9 +210,7 @@ func ParseNode(node *sitter.Node, source []byte, filename string, file *domain.T
 func Parse(ctx context.Context, source []byte, filename string, framework string) (*domain.TestFile, error) {
 	lang := DetectLanguage(filename)
 
-	p := parser.NewTSParser(lang)
-
-	tree, err := p.Parse(ctx, source)
+	tree, err := parser.ParseWithPool(ctx, lang, source)
 	if err != nil {
 		return nil, fmt.Errorf("failed to parse %s: %w", filename, err)
 	}