Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -171,6 +171,7 @@ This is an [MCP](https://modelcontextprotocol.io/introduction) server that runs

- `definition`: Retrieves the complete source code definition of any symbol (function, type, constant, etc.) from your codebase.
- `references`: Locates all usages and references of a symbol throughout the codebase.
- `incoming_calls`: Find all callers of a function or method throughout the codebase. Shows where the symbol is being called from.
- `diagnostics`: Provides diagnostic information for a specific file, including warnings and errors.
- `hover`: Display documentation, type hints, or other hover information for a given location.
- `rename_symbol`: Rename a symbol across a project.
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
---

/TEST_OUTPUT/workspace/main.go
Incoming Calls in File: 1
Callers: L12:C6 (main)

12|func main() {
13| fmt.Println(FooBar())
14|}
25 changes: 25 additions & 0 deletions integrationtests/snapshots/go/incoming_calls/helper-function.snap
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
---

/TEST_OUTPUT/workspace/another_consumer.go
Incoming Calls in File: 1
Callers: L6:C6 (AnotherConsumer)

6|func AnotherConsumer() {
7| // Use helper function
8| fmt.Println("Another message:", HelperFunction())
9|
10| // Create another SharedStruct instance
11| s := &SharedStruct{

---

/TEST_OUTPUT/workspace/consumer.go
Incoming Calls in File: 1
Callers: L6:C6 (ConsumerFunction)

6|func ConsumerFunction() {
7| message := HelperFunction()
8| fmt.Println(message)
9|
10| // Use shared struct
11| s := &SharedStruct{
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
No incoming calls found for symbol: SharedConstant
12 changes: 12 additions & 0 deletions integrationtests/snapshots/go/incoming_calls/struct-method.snap
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
---

/TEST_OUTPUT/workspace/consumer.go
Incoming Calls in File: 1
Callers: L6:C6 (ConsumerFunction)

6|func ConsumerFunction() {
7| message := HelperFunction()
8| fmt.Println(message)
9|
10| // Use shared struct
11| s := &SharedStruct{
99 changes: 99 additions & 0 deletions integrationtests/tests/go/incoming_calls/incoming_calls_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
package incoming_calls_test

import (
"context"
"strings"
"testing"
"time"

"github.com/isaacphi/mcp-language-server/integrationtests/tests/common"
"github.com/isaacphi/mcp-language-server/integrationtests/tests/go/internal"
"github.com/isaacphi/mcp-language-server/internal/tools"
)

// TestFindIncomingCalls tests the FindIncomingCalls tool with Go symbols
// that have callers in different files
func TestFindIncomingCalls(t *testing.T) {
suite := internal.GetTestSuite(t)

ctx, cancel := context.WithTimeout(suite.Context, 10*time.Second)
defer cancel()

tests := []struct {
name string
symbolName string
expectedText string
expectedFiles int // Number of files where callers should be found
snapshotName string
}{
{
name: "Function called from multiple files",
symbolName: "HelperFunction",
expectedText: "ConsumerFunction",
expectedFiles: 2, // consumer.go and another_consumer.go
snapshotName: "helper-function",
},
{
name: "Function called from same file",
symbolName: "FooBar",
expectedText: "main",
expectedFiles: 1, // main.go
snapshotName: "foobar-function",
},
{
name: "Method with callers",
symbolName: "SharedStruct.Method",
expectedText: "ConsumerFunction",
expectedFiles: 1, // consumer.go or another_consumer.go
snapshotName: "struct-method",
},
{
name: "No callers found",
symbolName: "SharedConstant",
expectedText: "No incoming calls found",
expectedFiles: 0,
snapshotName: "no-callers",
},
}

for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
// Call the FindIncomingCalls tool
result, err := tools.FindIncomingCalls(ctx, suite.Client, tc.symbolName)
if err != nil {
t.Fatalf("Failed to find incoming calls: %v", err)
}

// Check that the result contains relevant information
if !strings.Contains(result, tc.expectedText) {
t.Errorf("Incoming calls do not contain expected text: %s", tc.expectedText)
}

// Count how many different files are mentioned in the result
fileCount := countFilesInResult(result)
if tc.expectedFiles > 0 && fileCount < tc.expectedFiles {
t.Errorf("Expected incoming calls in at least %d files, but found in %d files",
tc.expectedFiles, fileCount)
}

// Use snapshot testing to verify exact output
common.SnapshotTest(t, "go", "incoming_calls", tc.snapshotName, result)
})
}
}

// countFilesInResult counts the number of unique files mentioned in the result
func countFilesInResult(result string) int {
fileMap := make(map[string]bool)

// Any line containing "workspace" and ".go" is a file path
for line := range strings.SplitSeq(result, "\n") {
if strings.Contains(line, "workspace") && strings.Contains(line, ".go") {
if !strings.Contains(line, "Incoming Calls in File") {
fileMap[line] = true
}
}
}

return len(fileMap)
}
2 changes: 0 additions & 2 deletions integrationtests/workspaces/go/go.mod
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
module github.com/isaacphi/mcp-language-server/integrationtests/test-output/go/workspace

go 1.20

require github.com/stretchr/testify v1.8.4 // unused import for codelens test
179 changes: 179 additions & 0 deletions internal/tools/incoming-calls.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,179 @@
package tools

import (
"context"
"fmt"
"os"
"sort"
"strconv"
"strings"

"github.com/isaacphi/mcp-language-server/internal/lsp"
"github.com/isaacphi/mcp-language-server/internal/protocol"
)

func FindIncomingCalls(ctx context.Context, client *lsp.Client, symbolName string) (string, error) {
// Get context lines from environment variable
contextLines := 5
if envLines := os.Getenv("LSP_CONTEXT_LINES"); envLines != "" {
if val, err := strconv.Atoi(envLines); err == nil && val >= 0 {
contextLines = val
}
}

// First get the symbol location like ReadDefinition does
symbolResult, err := client.Symbol(ctx, protocol.WorkspaceSymbolParams{
Query: symbolName,
})
if err != nil {
return "", fmt.Errorf("failed to fetch symbol: %v", err)
}

results, err := symbolResult.Results()
if err != nil {
return "", fmt.Errorf("failed to parse results: %v", err)
}

var allIncomingCalls []string
for _, symbol := range results {
// Handle different matching strategies based on the search term
if strings.Contains(symbolName, ".") {
// For qualified names like "Type.Method", check for various matches
parts := strings.Split(symbolName, ".")
methodName := parts[len(parts)-1]

// Try matching the unqualified method name for languages that don't use qualified names in symbols
if symbol.GetName() != symbolName && symbol.GetName() != methodName {
continue
}
} else if symbol.GetName() != symbolName {
// For unqualified names, exact match only
continue
}

// Get the location of the symbol
loc := symbol.GetLocation()

// Open the file
err := client.OpenFile(ctx, loc.URI.Path())
if err != nil {
toolsLogger.Error("Error opening file: %v", err)
continue
}

// Prepare call hierarchy
prepareParams := protocol.CallHierarchyPrepareParams{
TextDocumentPositionParams: protocol.TextDocumentPositionParams{
TextDocument: protocol.TextDocumentIdentifier{
URI: loc.URI,
},
Position: loc.Range.Start,
},
}

items, err := client.PrepareCallHierarchy(ctx, prepareParams)
if err != nil {
return "", fmt.Errorf("failed to prepare call hierarchy: %v", err)
}

if len(items) == 0 {
continue
}

// Get incoming calls for each item
for _, item := range items {
incomingCallsParams := protocol.CallHierarchyIncomingCallsParams{
Item: item,
}

incomingCalls, err := client.IncomingCalls(ctx, incomingCallsParams)
if err != nil {
return "", fmt.Errorf("failed to get incoming calls: %v", err)
}

if len(incomingCalls) == 0 {
continue
}

// Group calls by file
callsByFile := make(map[protocol.DocumentUri][]protocol.CallHierarchyIncomingCall)
for _, call := range incomingCalls {
callsByFile[call.From.URI] = append(callsByFile[call.From.URI], call)
}

// Get sorted list of URIs
uris := make([]string, 0, len(callsByFile))
for uri := range callsByFile {
uris = append(uris, string(uri))
}
sort.Strings(uris)

// Process each file's calls in sorted order
for _, uriStr := range uris {
uri := protocol.DocumentUri(uriStr)
fileCalls := callsByFile[uri]
filePath := strings.TrimPrefix(uriStr, "file://")

// Format file header
fileInfo := fmt.Sprintf("---\n\n%s\nIncoming Calls in File: %d\n",
filePath,
len(fileCalls),
)

// Format locations with context
fileContent, err := os.ReadFile(filePath)
if err != nil {
// Log error but continue with other files
allIncomingCalls = append(allIncomingCalls, fileInfo+"\nError reading file: "+err.Error())
continue
}

lines := strings.Split(string(fileContent), "\n")

// Track call locations for header display
var locStrings []string
var locations []protocol.Location
for _, call := range fileCalls {
// Add the caller location
loc := protocol.Location{
URI: call.From.URI,
Range: call.From.SelectionRange,
}
locations = append(locations, loc)

locStr := fmt.Sprintf("L%d:C%d (%s)",
call.From.SelectionRange.Start.Line+1,
call.From.SelectionRange.Start.Character+1,
call.From.Name)
locStrings = append(locStrings, locStr)
}

// Collect lines to display using the utility function
linesToShow, err := GetLineRangesToDisplay(ctx, client, locations, len(lines), contextLines)
if err != nil {
// Log error but continue with other files
continue
}

// Convert to line ranges using the utility function
lineRanges := ConvertLinesToRanges(linesToShow, len(lines))

// Format with locations in header
formattedOutput := fileInfo
if len(locStrings) > 0 {
formattedOutput += "Callers: " + strings.Join(locStrings, ", ") + "\n"
}

// Format the content with ranges
formattedOutput += "\n" + FormatLinesWithRanges(lines, lineRanges)
allIncomingCalls = append(allIncomingCalls, formattedOutput)
}
}
}

if len(allIncomingCalls) == 0 {
return fmt.Sprintf("No incoming calls found for symbol: %s", symbolName), nil
}

return strings.Join(allIncomingCalls, "\n"), nil
}
24 changes: 24 additions & 0 deletions tools.go
Original file line number Diff line number Diff line change
Expand Up @@ -363,6 +363,30 @@ func (s *mcpServer) registerTools() error {
return mcp.NewToolResultText(text), nil
})

incomingCallsTool := mcp.NewTool("incoming_calls",
mcp.WithDescription("Find all callers of a function or method throughout the codebase. Shows where the symbol is being called from (incoming calls)."),
mcp.WithString("symbolName",
mcp.Required(),
mcp.Description("The name of the function or method to find callers for (e.g. 'mypackage.MyFunction', 'MyType.MyMethod')"),
),
)

s.mcpServer.AddTool(incomingCallsTool, func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) {
// Extract arguments
symbolName, ok := request.Params.Arguments["symbolName"].(string)
if !ok {
return mcp.NewToolResultError("symbolName must be a string"), nil
}

coreLogger.Debug("Executing incoming_calls for symbol: %s", symbolName)
text, err := tools.FindIncomingCalls(s.ctx, s.lspClient, symbolName)
if err != nil {
coreLogger.Error("Failed to find incoming calls: %v", err)
return mcp.NewToolResultError(fmt.Sprintf("failed to find incoming calls: %v", err)), nil
}
return mcp.NewToolResultText(text), nil
})

coreLogger.Info("Successfully registered all MCP tools")
return nil
}