Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -509,12 +509,15 @@ private static void sendTrafficMetricsToTelemetry(BasicDBObject metricsData) {
}

public static boolean useHostCondition(String hostName, HttpResponseParams.Source source) {
List<HttpResponseParams.Source> whiteListSource = Arrays.asList(HttpResponseParams.Source.MIRRORING);
List<HttpResponseParams.Source> whiteListSource = Arrays.asList(HttpResponseParams.Source.MIRRORING, HttpResponseParams.Source.MCP_RECON);
boolean hostNameCondition;
if (hostName == null) {
hostNameCondition = false;
} else {
hostNameCondition = ! ( hostName.toLowerCase().equals(hostName.toUpperCase()) );
} else if (source.equals(HttpResponseParams.Source.MCP_RECON)) {
hostNameCondition = true;
}
else {
hostNameCondition = ! ( hostName.toLowerCase().equals(hostName.toUpperCase()) );
}
return whiteListSource.contains(source) && hostNameCondition && ApiCollection.useHost;
}
Expand Down Expand Up @@ -760,7 +763,7 @@ public List<HttpResponseParams> filterHttpResponseParams(List<HttpResponseParams
Map<String, List<ExecutorNode>> executorNodesMap = ParseAndExecute.createExecutorNodeMap(apiCatalogSync.advancedFilterMap);
for (HttpResponseParams httpResponseParam: httpResponseParamsList) {

if (httpResponseParam.getSource().equals(HttpResponseParams.Source.MIRRORING)) {
if (httpResponseParam.getSource().equals(HttpResponseParams.Source.MIRRORING) || httpResponseParam.getSource().equals(HttpResponseParams.Source.MCP_RECON)) {
TrafficMetrics.Key totalRequestsKey = getTrafficMetricsKey(httpResponseParam, TrafficMetrics.Name.TOTAL_REQUESTS_RUNTIME);
incTrafficMetrics(totalRequestsKey,1);
}
Expand Down Expand Up @@ -879,7 +882,7 @@ public List<HttpResponseParams> filterHttpResponseParams(List<HttpResponseParams
loggerMaker.infoAndAddToDb("Adding " + responseParamsList.size() + "new graphql endpoints in inventory");
}

if (httpResponseParam.getSource().equals(HttpResponseParams.Source.MIRRORING)) {
if (httpResponseParam.getSource().equals(HttpResponseParams.Source.MIRRORING) || httpResponseParam.getSource().equals(HttpResponseParams.Source.MCP_RECON)) {
TrafficMetrics.Key processedRequestsKey = getTrafficMetricsKey(httpResponseParam, TrafficMetrics.Name.FILTERED_REQUESTS_RUNTIME);
incTrafficMetrics(processedRequestsKey,1);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -483,7 +483,7 @@ public void run() {
}
}, 0, 24, TimeUnit.HOURS);

// schedule MCP Recon Sync job for 2 mins
// schedule MCP Recon Sync job for once in a day
loggerMaker.info("Scheduling MCP Recon Sync Job");
APIConfig finalApiConfigRecon = apiConfig;
scheduler.scheduleAtFixedRate(() -> {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -340,59 +340,26 @@ private void cleanFailedIpCache() {
* Optimized MCP verification using CompletableFuture with timeout management
*/
private List<McpServer> verifyMcpBatch(List<IpPortPair> openTargets) {
// Use CompletableFuture for better control
List<CompletableFuture<McpServer>> futures = new ArrayList<>(openTargets.size());
logger.info(String.format("Starting MCP verification for %d targets", openTargets.size()));
List<McpServer> results = new ArrayList<>();

for (IpPortPair target : openTargets) {
CompletableFuture<McpServer> future = CompletableFuture
.supplyAsync(() -> {
try {
semaphore.acquire();
try {
return verifySingleMcp(target.ip, target.port);
} finally {
semaphore.release();
}
} catch (InterruptedException e) {
Thread.currentThread().interrupt();
return null;
}
}, executorService);

// Add timeout using Java 8 compatible approach
final CompletableFuture<McpServer> timeoutFuture = new CompletableFuture<>();
executorService.submit(() -> {
try {
Thread.sleep(timeout * 2);
timeoutFuture.complete(null);
} catch (InterruptedException e) {
// Ignore
try {
logger.debug(String.format("Checking MCP at %s:%d", target.ip, target.port));
McpServer result = verifySingleMcp(target.ip, target.port);
if (result != null) {
logger.info(String.format("MCP server found at %s:%d", target.ip, target.port));
results.add(result);
} else {
logger.debug(String.format("No MCP server at %s:%d", target.ip, target.port));
}
});

CompletableFuture<McpServer> resultFuture = future
.applyToEither(timeoutFuture, result -> result)
.exceptionally(ex -> {
logger.debug("Verification timeout for " + target.ip + ":" + target.port);
return null;
});

futures.add(resultFuture);
}

// Wait for all completions or timeout
try {
CompletableFuture.allOf(futures.toArray(new CompletableFuture[0]))
.get(timeout * 3, TimeUnit.MILLISECONDS);
} catch (Exception e) {
// Continue with completed futures
} catch (Exception e) {
logger.debug("Verification error for " + target.ip + ":" + target.port + ": " + e.getMessage());
}
}

// Collect non-null results
return futures.stream()
.map(future -> future.getNow(null))
.filter(Objects::nonNull)
.collect(Collectors.toList());
logger.info(String.format("MCP verification completed. Found %d servers out of %d targets", results.size(), openTargets.size()));
return results;
}

/**
Expand Down Expand Up @@ -443,9 +410,6 @@ private McpServer checkSseEndpoint(String baseUrl, String ip, int port) {
if (response.getStatusCode() == 200) {
String contentType = response.getHeaders().getOrDefault("Content-Type", Collections.singletonList("")).get(0);
if (contentType != null && contentType.contains("text/event-stream")) {
String content = response.getBody() != null ? response.getBody().toLowerCase() : "";
for (String indicator : McpConstants.MCP_INDICATORS) {
if (content.contains(indicator.toLowerCase())) {
McpServer server = new McpServer();
server.setIp(ip);
server.setPort(port);
Expand All @@ -456,8 +420,6 @@ private McpServer checkSseEndpoint(String baseUrl, String ip, int port) {
server.setType("SSE");
server.setEndpoint(endpoint);
return server;
}
}
}
}
} catch (Exception e) {
Expand Down Expand Up @@ -539,8 +501,7 @@ private McpServer checkHttpEndpoints(String baseUrl, String ip, int port) {
Map<String, List<String>> headers = new HashMap<>();
OriginalHttpRequest request = new OriginalHttpRequest(url, "", "GET", null, headers, "");
OriginalHttpResponse response = ApiExecutor.sendRequest(request, true, null, false, new ArrayList<>());
String content = response.getBody();
if (isLikelyMcp(content)) {
if(response.getStatusCode() == 200) {
McpServer server = new McpServer();
server.setIp(ip);
server.setPort(port);
Expand All @@ -549,7 +510,7 @@ private McpServer checkHttpEndpoints(String baseUrl, String ip, int port) {
server.setDetectionMethod("HTTP");
server.setTimestamp(new Date().toString());
server.setEndpoint(endpoint);
List<String> detectedIndicators = getDetectedIndicators(content);
List<String> detectedIndicators = getDetectedIndicators("HTTP");
Map<String, Object> serverInfo = new HashMap<>();
serverInfo.put("detectedIndicators", detectedIndicators);
server.setServerInfo(serverInfo);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -321,11 +321,23 @@ private void processScanResults(List<ScanTaskResult> scanResults, APIConfig apiC
// Add to batch
serverBatch.add(mcpReconResult);

List<HttpResponseParams> toolsResponseList = toolsDiscovery(server);
List<HttpResponseParams> resourcesResponseList = resourcesDiscovery(server);

List<HttpResponseParams> toolsResponseList = new ArrayList<>();
List<HttpResponseParams> resourcesResponseList = new ArrayList<>();
List<HttpResponseParams> endpointsResponseList = new ArrayList<>();
if(server.getTools() != null) {
toolsResponseList = McpToolsSyncJobExecutor.INSTANCE.handleMcpToolsDiscovery(null, new HashSet<>(), true, server);
}
if(server.getResources() != null) {
resourcesResponseList = McpToolsSyncJobExecutor.INSTANCE.handleMcpResourceDiscovery(null, new HashSet<>(), true, server);
}
if(server.getEndpoint() != null && !server.getEndpoint().isEmpty()){
endpointsResponseList = handleMcpRequestDiscovery(server);
}
List<HttpResponseParams> responseParamsToProcess = new ArrayList<>();
responseParamsToProcess.addAll(toolsResponseList);
responseParamsToProcess.addAll(resourcesResponseList);
responseParamsToProcess.addAll(endpointsResponseList);
McpToolsSyncJobExecutor.processResponseParams(apiConfig, responseParamsToProcess);

// Insert when batch is full
Expand Down Expand Up @@ -494,73 +506,45 @@ private static class ScanCacheEntry {
}
}

public List<HttpResponseParams> handleMcpRequestDiscovery(McpServer mcpServer) {

private List<HttpResponseParams> toolsDiscovery(McpServer mcpServer) {
String host = mcpServer.getUrl();
String host = mcpServer.getIp() + ":" + mcpServer.getPort();
ObjectMapper mapper = new ObjectMapper();

List<HttpResponseParams> responseParamsList = new ArrayList<>();
try {
int id = 1;
String toolsCallRequestHeaders = buildHeaders(host);
for (McpSchema.Tool tool : mcpServer.getTools()) {
try {
String requestHeaders = buildHeaders(host);
McpSchema.JSONRPCRequest request = new McpSchema.JSONRPCRequest(
McpSchema.JSONRPC_VERSION,
McpSchema.METHOD_TOOLS_CALL,
id++,
new McpSchema.CallToolRequest(tool.getName(), McpToolsSyncJobExecutor.generateExampleArguments(tool.getInputSchema()))
);

HttpResponseParams toolsCallHttpResponseParams = McpToolsSyncJobExecutor.convertToAktoFormat(0,
mcpServer.getUrl(),
toolsCallRequestHeaders,
HttpMethod.POST.name(),
mapper.writeValueAsString(request),
new OriginalHttpResponse("", Collections.emptyMap(), HttpStatus.SC_OK));

if (toolsCallHttpResponseParams != null) {
responseParamsList.add(toolsCallHttpResponseParams);
}
}
} catch (Exception e) {
logger.error("Error while discovering mcp tools for hostname: {}", host, e);
}
return responseParamsList;
}


private List<HttpResponseParams> resourcesDiscovery(McpServer mcpServer) {
String host = mcpServer.getUrl();
ObjectMapper mapper = new ObjectMapper();

List<HttpResponseParams> responseParamsList = new ArrayList<>();
try {
int id = 1;
String resourceCallRequestHeaders = buildHeaders(host);
for (McpSchema.Resource resource : mcpServer.getResources()) {
McpSchema.JSONRPCRequest request = new McpSchema.JSONRPCRequest(
McpSchema.JSONRPC_VERSION,
McpSchema.METHOD_RESOURCES_READ,
id++,
new McpSchema.ReadResourceRequest(resource.getUri())
);

HttpResponseParams readResourceHttpResponseParams = McpToolsSyncJobExecutor.convertToAktoFormat(0,
mcpServer.getUrl(),
resourceCallRequestHeaders,
HttpMethod.POST.name(),
mapper.writeValueAsString(request),
new OriginalHttpResponse("", Collections.emptyMap(), HttpStatus.SC_OK));
McpSchema.JSONRPC_VERSION,
McpSchema.METHOD_PING,
String.valueOf(1),
new McpSchema.InitializeRequest(
McpSchema.LATEST_PROTOCOL_VERSION,
new McpSchema.ClientCapabilities(
null,
null,
null
),
new McpSchema.Implementation("akto-api-recon-scan", "1.0.0")
)
);

HttpResponseParams requestHttpResponseParams = McpToolsSyncJobExecutor.convertToAktoFormat(0,
mcpServer.getUrl(),
requestHeaders,
HttpMethod.GET.name(),
mapper.writeValueAsString(request),
new OriginalHttpResponse("", Collections.emptyMap(), HttpStatus.SC_OK));

if (requestHttpResponseParams != null) {
requestHttpResponseParams.setSource(HttpResponseParams.Source.MCP_RECON);
responseParamsList.add(requestHttpResponseParams);
}

if (readResourceHttpResponseParams != null) {
responseParamsList.add(readResourceHttpResponseParams);
}
} catch (Exception e) {
logger.error("Error while discovering mcp resources for hostname: {}", host, e);
}
} catch (Exception e) {
logger.error("Error while discovering mcp resources for hostname: {}", host, e);
return responseParamsList;
}
return responseParamsList;
}

private String buildHeaders(String host) {
return "{\"Content-Type\":\"application/json\",\"Accept\":\"*/*\",\"host\":\"" + host + "\"}";
Expand Down
Loading
Loading