diff --git a/src/pkg/parser/strategies/gotesting/gotesting.go b/src/pkg/parser/strategies/gotesting/gotesting.go new file mode 100644 index 0000000..3e4e89b --- /dev/null +++ b/src/pkg/parser/strategies/gotesting/gotesting.go @@ -0,0 +1,246 @@ +package gotesting + +import ( + "context" + "fmt" + "path/filepath" + "strconv" + "strings" + "unicode" + + sitter "github.com/smacker/go-tree-sitter" + + "github.com/specvital/core/domain" + "github.com/specvital/core/parser" + "github.com/specvital/core/parser/strategies" +) + +const ( + frameworkName = "go-testing" + + // AST node types + nodeCallExpression = "call_expression" + nodeFunctionDeclaration = "function_declaration" + nodeParameterDeclaration = "parameter_declaration" + nodePointerType = "pointer_type" + nodeQualifiedType = "qualified_type" + nodeSelectorExpression = "selector_expression" + + // String literal types + nodeInterpretedStringLiteral = "interpreted_string_literal" + nodeRawStringLiteral = "raw_string_literal" + + // Go test identifiers + methodRun = "Run" + typeTestingParam = "testing.T" +) + +type Strategy struct{} + +func NewStrategy() *Strategy { + return &Strategy{} +} + +func RegisterDefault() { + strategies.Register(NewStrategy()) +} + +func (s *Strategy) Name() string { + return frameworkName +} + +func (s *Strategy) Priority() int { + return strategies.DefaultPriority +} + +func (s *Strategy) Languages() []domain.Language { + return []domain.Language{domain.LanguageGo} +} + +func (s *Strategy) CanHandle(filename string, _ []byte) bool { + return isGoTestFile(filename) +} + +func (s *Strategy) Parse(ctx context.Context, source []byte, filename string) (*domain.TestFile, error) { + p := parser.NewTSParser(domain.LanguageGo) + + tree, err := p.Parse(ctx, source) + if err != nil { + return nil, fmt.Errorf("go-testing parser: failed to parse %s: %w", filename, err) + } + defer tree.Close() + root := tree.RootNode() + + suites, tests := parseTestFunctions(root, source, filename) + + testFile := &domain.TestFile{ + Path: filename, + Language: domain.LanguageGo, + Framework: frameworkName, + Suites: suites, + Tests: tests, + } + + return testFile, nil +} + +// Helper functions (alphabetically ordered) + +func extractSubtests(body *sitter.Node, source []byte, filename string) []domain.Test { + var subtests []domain.Test + + parser.WalkTree(body, func(node *sitter.Node) bool { + if node.Type() != nodeCallExpression { + return true + } + + funcNode := node.ChildByFieldName("function") + if funcNode == nil || funcNode.Type() != nodeSelectorExpression { + return true + } + + field := funcNode.ChildByFieldName("field") + if field == nil || parser.GetNodeText(field, source) != methodRun { + return true + } + + args := node.ChildByFieldName("arguments") + if args == nil { + return true + } + + name := extractSubtestName(args, source) + if name == "" { + return true + } + + subtests = append(subtests, domain.Test{ + Name: name, + Status: domain.TestStatusPending, + Location: parser.GetLocation(node, filename), + }) + + return true + }) + + return subtests +} + +func extractSubtestName(args *sitter.Node, source []byte) string { + for i := 0; i < int(args.ChildCount()); i++ { + child := args.Child(i) + switch child.Type() { + case nodeInterpretedStringLiteral, nodeRawStringLiteral: + return trimQuotes(parser.GetNodeText(child, source)) + } + } + return "" +} + +func extractTestName(funcDecl *sitter.Node, source []byte) string { + nameNode := funcDecl.ChildByFieldName("name") + if nameNode == nil { + return "" + } + return parser.GetNodeText(nameNode, source) +} + +func isGoTestFile(filename string) bool { + base := filepath.Base(filename) + return strings.HasSuffix(base, "_test.go") +} + +func isTestFunction(name string) bool { + if !strings.HasPrefix(name, "Test") || len(name) <= 4 { + return false + } + return unicode.IsUpper(rune(name[4])) +} + +func parseTestFunctions(root *sitter.Node, source []byte, filename string) ([]domain.TestSuite, []domain.Test) { + var suites []domain.TestSuite + var tests []domain.Test + + for i := 0; i < int(root.ChildCount()); i++ { + child := root.Child(i) + if child.Type() != nodeFunctionDeclaration { + continue + } + + name := extractTestName(child, source) + if !isTestFunction(name) { + continue + } + + if !validateTestParams(child, source) { + continue + } + + body := child.ChildByFieldName("body") + var subtests []domain.Test + if body != nil { + subtests = extractSubtests(body, source, filename) + } + + if len(subtests) > 0 { + suite := domain.TestSuite{ + Name: name, + Status: domain.TestStatusPending, + Location: parser.GetLocation(child, filename), + Tests: subtests, + } + suites = append(suites, suite) + } else { + test := domain.Test{ + Name: name, + Status: domain.TestStatusPending, + Location: parser.GetLocation(child, filename), + } + tests = append(tests, test) + } + } + + return suites, tests +} + +func trimQuotes(s string) string { + if unquoted, err := strconv.Unquote(s); err == nil { + return unquoted + } + // Fallback for invalid literals, e.g. from incomplete code. + if len(s) >= 2 && s[0] == s[len(s)-1] && (s[0] == '"' || s[0] == '`') { + return s[1 : len(s)-1] + } + return s +} + +func validateTestParams(funcDecl *sitter.Node, source []byte) bool { + params := funcDecl.ChildByFieldName("parameters") + if params == nil { + return false + } + + var paramDecl *sitter.Node + paramCount := 0 + for i := 0; i < int(params.ChildCount()); i++ { + child := params.Child(i) + if child.Type() == nodeParameterDeclaration { + if paramCount == 0 { + paramDecl = child + } + paramCount++ + } + } + + if paramCount != 1 { + return false + } + + typeNode := paramDecl.ChildByFieldName("type") + if typeNode == nil || typeNode.Type() != nodePointerType { + return false + } + + elem := parser.FindChildByType(typeNode, nodeQualifiedType) + return elem != nil && parser.GetNodeText(elem, source) == typeTestingParam +} diff --git a/src/pkg/parser/strategies/gotesting/gotesting_test.go b/src/pkg/parser/strategies/gotesting/gotesting_test.go new file mode 100644 index 0000000..d3ffa7b --- /dev/null +++ b/src/pkg/parser/strategies/gotesting/gotesting_test.go @@ -0,0 +1,403 @@ +package gotesting + +import ( + "context" + "testing" + + "github.com/specvital/core/domain" +) + +func TestStrategy_Name(t *testing.T) { + s := NewStrategy() + if got := s.Name(); got != "go-testing" { + t.Errorf("Name() = %v, want %v", got, "go-testing") + } +} + +func TestStrategy_Priority(t *testing.T) { + s := NewStrategy() + if got := s.Priority(); got != 100 { + t.Errorf("Priority() = %v, want %v", got, 100) + } +} + +func TestStrategy_Languages(t *testing.T) { + s := NewStrategy() + langs := s.Languages() + if len(langs) != 1 || langs[0] != domain.LanguageGo { + t.Errorf("Languages() = %v, want [go]", langs) + } +} + +func TestStrategy_CanHandle(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + filename string + want bool + }{ + {"valid test file", "user_test.go", true}, + {"test file in directory", "pkg/service/user_test.go", true}, + {"non-test go file", "user.go", false}, + {"typescript test file", "user.test.ts", false}, + {"javascript test file", "user.spec.js", false}, + {"test directory", "test/main.go", false}, + } + + s := NewStrategy() + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + if got := s.CanHandle(tt.filename, nil); got != tt.want { + t.Errorf("CanHandle(%v) = %v, want %v", tt.filename, got, tt.want) + } + }) + } +} + +func TestStrategy_Parse_SimpleTestFunction(t *testing.T) { + source := []byte(`package user + +import "testing" + +func TestCreate(t *testing.T) { + // test implementation +} +`) + + s := NewStrategy() + result, err := s.Parse(context.Background(), source, "user_test.go") + if err != nil { + t.Fatalf("Parse() error = %v", err) + } + + if result.Framework != "go-testing" { + t.Errorf("Framework = %v, want go-testing", result.Framework) + } + + if result.Language != domain.LanguageGo { + t.Errorf("Language = %v, want go", result.Language) + } + + if len(result.Tests) != 1 { + t.Fatalf("len(Tests) = %v, want 1", len(result.Tests)) + } + + if result.Tests[0].Name != "TestCreate" { + t.Errorf("Tests[0].Name = %v, want TestCreate", result.Tests[0].Name) + } +} + +func TestStrategy_Parse_MultipleTestFunctions(t *testing.T) { + source := []byte(`package user + +import "testing" + +func TestCreate(t *testing.T) {} +func TestUpdate(t *testing.T) {} +func TestDelete(t *testing.T) {} +`) + + s := NewStrategy() + result, err := s.Parse(context.Background(), source, "user_test.go") + if err != nil { + t.Fatalf("Parse() error = %v", err) + } + + if len(result.Tests) != 3 { + t.Fatalf("len(Tests) = %v, want 3", len(result.Tests)) + } + + expectedNames := []string{"TestCreate", "TestUpdate", "TestDelete"} + for i, name := range expectedNames { + if result.Tests[i].Name != name { + t.Errorf("Tests[%d].Name = %v, want %v", i, result.Tests[i].Name, name) + } + } +} + +func TestStrategy_Parse_Subtests(t *testing.T) { + source := []byte(`package user + +import "testing" + +func TestUser(t *testing.T) { + t.Run("create", func(t *testing.T) {}) + t.Run("update", func(t *testing.T) {}) +} +`) + + s := NewStrategy() + result, err := s.Parse(context.Background(), source, "user_test.go") + if err != nil { + t.Fatalf("Parse() error = %v", err) + } + + if len(result.Tests) != 0 { + t.Errorf("len(Tests) = %v, want 0 (should be in Suites)", len(result.Tests)) + } + + if len(result.Suites) != 1 { + t.Fatalf("len(Suites) = %v, want 1", len(result.Suites)) + } + + suite := result.Suites[0] + if suite.Name != "TestUser" { + t.Errorf("Suites[0].Name = %v, want TestUser", suite.Name) + } + + if len(suite.Tests) != 2 { + t.Fatalf("len(Suites[0].Tests) = %v, want 2", len(suite.Tests)) + } + + expectedSubtests := []string{"create", "update"} + for i, name := range expectedSubtests { + if suite.Tests[i].Name != name { + t.Errorf("Suites[0].Tests[%d].Name = %v, want %v", i, suite.Tests[i].Name, name) + } + } +} + +func TestStrategy_Parse_MixedTestsAndSubtests(t *testing.T) { + source := []byte(`package user + +import "testing" + +func TestSimple(t *testing.T) {} + +func TestWithSubtests(t *testing.T) { + t.Run("sub1", func(t *testing.T) {}) + t.Run("sub2", func(t *testing.T) {}) +} + +func TestAnotherSimple(t *testing.T) {} +`) + + s := NewStrategy() + result, err := s.Parse(context.Background(), source, "user_test.go") + if err != nil { + t.Fatalf("Parse() error = %v", err) + } + + if len(result.Tests) != 2 { + t.Errorf("len(Tests) = %v, want 2", len(result.Tests)) + } + + if len(result.Suites) != 1 { + t.Fatalf("len(Suites) = %v, want 1", len(result.Suites)) + } +} + +func TestStrategy_Parse_IgnoresNonTestFunctions(t *testing.T) { + source := []byte(`package user + +import "testing" + +func TestValid(t *testing.T) {} +func helperFunction() {} +func Test(t *testing.T) {} // Too short, should be ignored +func BenchmarkSomething(b *testing.B) {} // Benchmark, not test +func Test1(t *testing.T) {} // lowercase after Test, should be ignored +func Test_(t *testing.T) {} // underscore after Test, should be ignored +func TestWrongParams(i int, t *testing.T) {} // wrong parameters, should be ignored +`) + + s := NewStrategy() + result, err := s.Parse(context.Background(), source, "user_test.go") + if err != nil { + t.Fatalf("Parse() error = %v", err) + } + + if len(result.Tests) != 1 { + t.Fatalf("len(Tests) = %v, want 1", len(result.Tests)) + } + + if result.Tests[0].Name != "TestValid" { + t.Errorf("Tests[0].Name = %v, want TestValid", result.Tests[0].Name) + } +} + +func TestStrategy_Parse_Location(t *testing.T) { + source := []byte(`package user + +import "testing" + +func TestCreate(t *testing.T) { + // test +} +`) + + s := NewStrategy() + result, err := s.Parse(context.Background(), source, "user_test.go") + if err != nil { + t.Fatalf("Parse() error = %v", err) + } + + if len(result.Tests) != 1 { + t.Fatalf("len(Tests) = %v, want 1", len(result.Tests)) + } + + loc := result.Tests[0].Location + if loc.File != "user_test.go" { + t.Errorf("Location.File = %v, want user_test.go", loc.File) + } + if loc.StartLine != 5 { + t.Errorf("Location.StartLine = %v, want 5", loc.StartLine) + } +} + +func TestStrategy_Parse_VerificationExample(t *testing.T) { + // Test from plan.md verification method + source := []byte(`package user + +import "testing" + +func TestUser(t *testing.T) { + t.Run("create", func(t *testing.T) {}) +} +`) + + s := NewStrategy() + result, err := s.Parse(context.Background(), source, "user_test.go") + if err != nil { + t.Fatalf("Parse() error = %v", err) + } + + // Should have TestUser suite + create subtest + if len(result.Suites) != 1 { + t.Fatalf("len(Suites) = %v, want 1", len(result.Suites)) + } + + if result.Suites[0].Name != "TestUser" { + t.Errorf("Suites[0].Name = %v, want TestUser", result.Suites[0].Name) + } + + if len(result.Suites[0].Tests) != 1 { + t.Fatalf("len(Suites[0].Tests) = %v, want 1", len(result.Suites[0].Tests)) + } + + if result.Suites[0].Tests[0].Name != "create" { + t.Errorf("Suites[0].Tests[0].Name = %v, want create", result.Suites[0].Tests[0].Name) + } +} + +func TestStrategy_Parse_RawStringLiteral(t *testing.T) { + source := []byte(`package user + +import "testing" + +func TestUser(t *testing.T) { + t.Run(` + "`" + `raw string name` + "`" + `, func(t *testing.T) {}) +} +`) + + s := NewStrategy() + result, err := s.Parse(context.Background(), source, "user_test.go") + if err != nil { + t.Fatalf("Parse() error = %v", err) + } + + if len(result.Suites) != 1 { + t.Fatalf("len(Suites) = %v, want 1", len(result.Suites)) + } + + if len(result.Suites[0].Tests) != 1 { + t.Fatalf("len(Suites[0].Tests) = %v, want 1", len(result.Suites[0].Tests)) + } + + if result.Suites[0].Tests[0].Name != "raw string name" { + t.Errorf("subtest name = %v, want 'raw string name'", result.Suites[0].Tests[0].Name) + } +} + +func TestStrategy_Parse_NestedSubtests(t *testing.T) { + source := []byte(`package user + +import "testing" + +func TestUser(t *testing.T) { + t.Run("level1", func(t *testing.T) { + t.Run("level2", func(t *testing.T) {}) + }) +} +`) + + s := NewStrategy() + result, err := s.Parse(context.Background(), source, "user_test.go") + if err != nil { + t.Fatalf("Parse() error = %v", err) + } + + if len(result.Suites) != 1 { + t.Fatalf("len(Suites) = %v, want 1", len(result.Suites)) + } + + // Both subtests should be extracted (flat structure) + if len(result.Suites[0].Tests) != 2 { + t.Fatalf("len(Suites[0].Tests) = %v, want 2", len(result.Suites[0].Tests)) + } +} + +func TestStrategy_Parse_EmptySource(t *testing.T) { + source := []byte(``) + + s := NewStrategy() + result, err := s.Parse(context.Background(), source, "empty_test.go") + if err != nil { + t.Fatalf("Parse() error = %v", err) + } + + if len(result.Tests) != 0 { + t.Errorf("len(Tests) = %v, want 0", len(result.Tests)) + } + if len(result.Suites) != 0 { + t.Errorf("len(Suites) = %v, want 0", len(result.Suites)) + } +} + +func TestStrategy_Parse_NoTestFunctions(t *testing.T) { + source := []byte(`package user + +func helperFunction() {} +func anotherHelper(s string) int { return len(s) } +`) + + s := NewStrategy() + result, err := s.Parse(context.Background(), source, "user_test.go") + if err != nil { + t.Fatalf("Parse() error = %v", err) + } + + if len(result.Tests) != 0 { + t.Errorf("len(Tests) = %v, want 0", len(result.Tests)) + } +} + +func TestStrategy_Parse_TestFunctionWithNoBody(t *testing.T) { + // Interface method declaration style (no body) + source := []byte(`package user + +import "testing" + +type Tester interface { + TestSomething(t *testing.T) +} + +func TestReal(t *testing.T) {} +`) + + s := NewStrategy() + result, err := s.Parse(context.Background(), source, "user_test.go") + if err != nil { + t.Fatalf("Parse() error = %v", err) + } + + // Should only find TestReal, not interface method + if len(result.Tests) != 1 { + t.Fatalf("len(Tests) = %v, want 1", len(result.Tests)) + } + if result.Tests[0].Name != "TestReal" { + t.Errorf("Tests[0].Name = %v, want TestReal", result.Tests[0].Name) + } +} diff --git a/src/pkg/parser/treesitter.go b/src/pkg/parser/treesitter.go index b9161c3..b5113a4 100644 --- a/src/pkg/parser/treesitter.go +++ b/src/pkg/parser/treesitter.go @@ -6,6 +6,7 @@ import ( "sync" sitter "github.com/smacker/go-tree-sitter" + "github.com/smacker/go-tree-sitter/golang" "github.com/smacker/go-tree-sitter/javascript" "github.com/smacker/go-tree-sitter/typescript/typescript" @@ -20,8 +21,9 @@ type TSParser struct { } var ( - tsLang *sitter.Language + goLang *sitter.Language jsLang *sitter.Language + tsLang *sitter.Language langOnce sync.Once ) @@ -33,13 +35,16 @@ type QueryResult struct { func initLanguages() { langOnce.Do(func() { - tsLang = typescript.GetLanguage() + goLang = golang.GetLanguage() jsLang = javascript.GetLanguage() + tsLang = typescript.GetLanguage() }) } func getSitterLanguage(lang domain.Language) *sitter.Language { switch lang { + case domain.LanguageGo: + return goLang case domain.LanguageJavaScript: return jsLang default: @@ -51,15 +56,7 @@ func NewTSParser(lang domain.Language) *TSParser { initLanguages() parser := sitter.NewParser() - - switch lang { - case domain.LanguageTypeScript: - parser.SetLanguage(tsLang) - case domain.LanguageJavaScript: - parser.SetLanguage(jsLang) - default: - parser.SetLanguage(tsLang) - } + parser.SetLanguage(getSitterLanguage(lang)) return &TSParser{ parser: parser, diff --git a/src/pkg/parser/treesitter_test.go b/src/pkg/parser/treesitter_test.go index dc3cfa3..c5136bc 100644 --- a/src/pkg/parser/treesitter_test.go +++ b/src/pkg/parser/treesitter_test.go @@ -18,7 +18,7 @@ func TestNewTSParser(t *testing.T) { }{ {"should create parser for TypeScript", domain.LanguageTypeScript}, {"should create parser for JavaScript", domain.LanguageJavaScript}, - {"should default to TypeScript for unknown language", domain.LanguageGo}, + {"should create parser for Go", domain.LanguageGo}, } for _, tt := range tests {