diff --git a/apps/database-abstractor/src/main/java/com/akto/action/DbAction.java b/apps/database-abstractor/src/main/java/com/akto/action/DbAction.java index d389c28933..246ed9f800 100644 --- a/apps/database-abstractor/src/main/java/com/akto/action/DbAction.java +++ b/apps/database-abstractor/src/main/java/com/akto/action/DbAction.java @@ -9,6 +9,14 @@ import com.akto.dto.*; import com.akto.dto.ApiInfo.ApiInfoKey; import com.akto.dto.McpReconRequest; +import com.akto.dao.mcp.MCPGuardrailYamlTemplateDao; +import com.akto.dto.mcp.MCPGuardrailConfig; +import com.akto.dto.mcp.MCPGuardrailConfigYamlParser; +import com.akto.dto.mcp.MCPGuardrailType; +import com.akto.utils.MCPGuardrailUtil; +import com.mongodb.client.model.Filters; +import com.mongodb.client.model.Updates; +import org.bson.conversions.Bson; import com.akto.dto.billing.Organization; import com.akto.dto.billing.Tokens; import com.akto.dto.billing.UningestedApiOverage; @@ -44,6 +52,7 @@ import com.akto.notifications.slack.SlackSender; import com.akto.util.enums.GlobalEnums; import com.akto.utils.CustomAuthUtil; +import com.akto.util.Constants; import com.akto.utils.KafkaUtils; import com.akto.utils.RedactAlert; import com.akto.utils.SampleDataLogs; @@ -137,6 +146,31 @@ public class DbAction extends ActionSupport { @Getter @Setter private List serverDataList; + // MCP Guardrails fields + @Getter @Setter + private List mcpGuardrailTemplates; + + @Getter @Setter + private Map mcpGuardrailConfigs; + + @Getter @Setter + private YamlTemplate mcpGuardrailTemplate; + + @Getter @Setter + private String templateId; + + @Getter @Setter + private String guardrailType; + + @Getter @Setter + private boolean includeYamlContent = true; + + @Getter @Setter + private boolean activeOnly = true; + + @Getter @Setter + private String content; + private ModuleInfo moduleInfo; private static final LoggerMaker loggerMaker = new LoggerMaker(DbAction.class, LogDb.DB_ABS); @@ -2874,6 +2908,160 @@ public String 
storeMcpReconResultsBatch() { return Action.SUCCESS.toUpperCase(); } + /** + * Fetch MCP Guardrail YAML templates: only the active ones when activeOnly is set, otherwise every template + */ + public String fetchMCPGuardrailTemplates() { + try { + if (activeOnly) { + mcpGuardrailTemplates = MCPGuardrailYamlTemplateDao.instance.fetchActiveTemplates(); + } else { + mcpGuardrailTemplates = MCPGuardrailYamlTemplateDao.instance.findAll(Filters.empty()); + } + + loggerMaker.infoAndAddToDb("Fetched " + mcpGuardrailTemplates.size() + " MCP Guardrail templates", LogDb.DASHBOARD); + return Action.SUCCESS.toUpperCase(); + } catch (Exception e) { + loggerMaker.errorAndAddToDb(e, "Error fetching MCP Guardrail templates: " + e.toString()); + addActionError("Failed to fetch MCP Guardrail templates"); + return Action.ERROR.toUpperCase(); + } + } + + /** + * Fetch MCP Guardrail templates by type; guardrailType is required and is trimmed and upper-cased with Locale.ROOT before the lookup, so the value queried is the same one that was validated and does not depend on the JVM default locale + */ + public String fetchMCPGuardrailTemplatesByType() { + try { + if (guardrailType == null || guardrailType.trim().isEmpty()) { + addActionError("Guardrail type is required"); + return Action.ERROR.toUpperCase(); + } + + mcpGuardrailTemplates = MCPGuardrailYamlTemplateDao.instance.fetchTemplatesByType(guardrailType.trim().toUpperCase(java.util.Locale.ROOT)); + + loggerMaker.infoAndAddToDb("Fetched " + mcpGuardrailTemplates.size() + " MCP Guardrail templates for type: " + guardrailType, LogDb.DASHBOARD); + return Action.SUCCESS.toUpperCase(); + } catch (Exception e) { + loggerMaker.errorAndAddToDb(e, "Error fetching MCP Guardrail templates by type: " + e.toString()); + addActionError("Failed to fetch MCP Guardrail templates by type"); + return Action.ERROR.toUpperCase(); + } + } + + /** + * Fetch parsed MCP Guardrail configurations; includeYamlContent is passed through to the DAO + */ + public String fetchMCPGuardrailConfigs() { + try { + mcpGuardrailConfigs = MCPGuardrailYamlTemplateDao.instance.fetchMCPGuardrailConfig(includeYamlContent); + + loggerMaker.infoAndAddToDb("Fetched " + mcpGuardrailConfigs.size() + " MCP Guardrail configurations", LogDb.DASHBOARD); + return Action.SUCCESS.toUpperCase(); + } catch (Exception e) { +
loggerMaker.errorAndAddToDb(e, "Error fetching MCP Guardrail configurations: " + e.toString()); + addActionError("Failed to fetch MCP Guardrail configurations"); + return Action.ERROR.toUpperCase(); + } + } + + /** + * Fetch a specific MCP Guardrail template by ID; templateId is required and is matched on Constants.ID, the same key saveMCPGuardrailTemplate filters on when writing, and a missing template is reported as an action error + */ + public String fetchMCPGuardrailTemplate() { + try { + if (templateId == null || templateId.trim().isEmpty()) { + addActionError("Template ID is required"); + return Action.ERROR.toUpperCase(); + } + + mcpGuardrailTemplate = MCPGuardrailYamlTemplateDao.instance.findOne(Constants.ID, templateId); + + if (mcpGuardrailTemplate == null) { + addActionError("MCP Guardrail template not found"); + return Action.ERROR.toUpperCase(); + } + + loggerMaker.infoAndAddToDb("Fetched MCP Guardrail template: " + templateId, LogDb.DASHBOARD); + return Action.SUCCESS.toUpperCase(); + } catch (Exception e) { + loggerMaker.errorAndAddToDb(e, "Error fetching MCP Guardrail template: " + e.toString()); + addActionError("Failed to fetch MCP Guardrail template"); + return Action.ERROR.toUpperCase(); + } + } + + /** + * Get available MCP Guardrail types (builds a JSON array string of the MCPGuardrailType enum names) + */ + public String fetchMCPGuardrailTypes() { + try { + // NOTE(review): typesJson is assembled from all enum names but is never assigned to an action field, so it is not part of the response; expose it through a field and getter if callers need the list + MCPGuardrailType[] types = MCPGuardrailType.values(); + StringBuilder typesJson = new StringBuilder("["); + for (int i = 0; i < types.length; i++) { + if (i > 0) typesJson.append(","); + typesJson.append("\"").append(types[i].name()).append("\""); + } + typesJson.append("]"); + + loggerMaker.infoAndAddToDb("Fetched MCP Guardrail types", LogDb.DASHBOARD); + return Action.SUCCESS.toUpperCase(); + } catch (Exception e) { + loggerMaker.errorAndAddToDb(e, "Error fetching MCP Guardrail types: " + e.toString()); + addActionError("Failed to fetch MCP Guardrail types"); + return Action.ERROR.toUpperCase(); + } + } + + /** + * Save MCP Guardrail YAML template (content is required and must parse to a config with an ID and a valid filter) + */ + public String saveMCPGuardrailTemplate() { + try { + if (content == null || content.trim().isEmpty()) { +
addActionError("Template content is required"); + + return Action.ERROR.toUpperCase(); + } + + // Parse the YAML content; the checks below reject a config with no ID, no filter, or an invalid filter + MCPGuardrailConfig guardrailConfig = MCPGuardrailConfigYamlParser.parseTemplate(content); + + if (guardrailConfig.getId() == null || guardrailConfig.getId().trim().isEmpty()) { + addActionError("Template ID is required"); + return Action.ERROR.toUpperCase(); + } + + if (guardrailConfig.getFilter() == null) { + addActionError("Template filter configuration is required"); + return Action.ERROR.toUpperCase(); + } + + if (!guardrailConfig.getFilter().getIsValid()) { + addActionError("Invalid filter configuration: " + guardrailConfig.getFilter().getErrMsg()); + return Action.ERROR.toUpperCase(); + } + + // Get database updates for the template + String userEmail = "system"; // Hard-coded to "system"; nothing in this method reads a user email from the request + List updates = MCPGuardrailUtil.getDbUpdateForTemplate(content, userEmail); + + // Update the template matched on Constants.ID (NOTE(review): creating a new template relies on updateOne upserting - confirm against the DAO) + MCPGuardrailYamlTemplateDao.instance.updateOne( + Filters.eq(Constants.ID, guardrailConfig.getId()), + Updates.combine(updates)); + + loggerMaker.infoAndAddToDb("Saved MCP Guardrail template with ID: " + guardrailConfig.getId(), LogDb.DASHBOARD); + return Action.SUCCESS.toUpperCase(); + + } catch (Exception e) { + loggerMaker.errorAndAddToDb(e, "Error saving MCP Guardrail template: " + e.toString()); + addActionError("Failed to save MCP Guardrail template: " + e.getMessage()); + return Action.ERROR.toUpperCase(); + } + } + public List getCustomDataTypes() { return customDataTypes; } diff --git a/apps/database-abstractor/src/main/resources/struts.xml b/apps/database-abstractor/src/main/resources/struts.xml index d1713349a0..8e4b180290 100644 --- a/apps/database-abstractor/src/main/resources/struts.xml +++ b/apps/database-abstractor/src/main/resources/struts.xml @@ -1793,6 +1793,73 @@ + + + + + + + 422 + false + ^actionErrors.* + + + + + + + + + 422 + false + ^actionErrors.* + + + + + + + + + 422 +
false + ^actionErrors.* + + + + + + + + + 422 + false + ^actionErrors.* + + + + + + + + + 422 + false + ^actionErrors.* + + + + + + + + + 422 + false + ^actionErrors.* + + + diff --git a/apps/mcp-guardrails/PROJECT_STRUCTURE.md b/apps/mcp-guardrails/PROJECT_STRUCTURE.md deleted file mode 100644 index 0519ecba6e..0000000000 --- a/apps/mcp-guardrails/PROJECT_STRUCTURE.md +++ /dev/null @@ -1 +0,0 @@ - \ No newline at end of file diff --git a/apps/mcp-guardrails/README.md b/apps/mcp-guardrails/README.md deleted file mode 100644 index fcd5559cb1..0000000000 --- a/apps/mcp-guardrails/README.md +++ /dev/null @@ -1,322 +0,0 @@ -# MCP Guardrails Library - -A Go library for implementing security guardrails in Model Context Protocol (MCP) proxy servers. This library provides comprehensive data sanitization, content filtering, rate limiting, and input validation capabilities to secure MCP communications. - -## Features - -- **Data Sanitization**: Automatically redact sensitive information like credit cards, SSNs, emails, API keys, and passwords -- **Content Filtering**: Block or warn about malicious content, SQL injection attempts, XSS, and other security threats -- **Rate Limiting**: Implement token bucket rate limiting to prevent abuse -- **Input Validation**: Validate MCP requests and parameters -- **Output Filtering**: Filter and sanitize MCP responses -- **Comprehensive Logging**: Detailed logging of all guardrail operations -- **Extensible**: Easy to add custom patterns and filters - -## Installation - -```bash -go get github.com/akto/mcp-guardrails -``` - -## Quick Start - -```go -package main - -import ( - "encoding/json" - "log" - "time" - - "github.com/akto/mcp-guardrails" -) - -func main() { - // Create guardrail configuration - config := &guardrails.GuardrailConfig{ - EnableDataSanitization: true, - SensitiveFields: []string{"password", "api_key", "secret"}, - - EnableContentFiltering: true, - BlockedKeywords: []string{"malicious", "dangerous"}, - - 
EnableRateLimiting: true, - RateLimitConfig: guardrails.RateLimitConfig{ - RequestsPerMinute: 100, - BurstSize: 10, - WindowSize: time.Minute, - }, - - EnableInputValidation: true, - ValidationRules: map[string]string{ - "method": "required", - }, - - EnableLogging: true, - LogLevel: "INFO", - } - - // Create guardrail engine - engine := guardrails.NewGuardrailEngine(config) - - // Process MCP response - response := &guardrails.MCPResponse{ - ID: "1", - Result: json.RawMessage(`{ - "user": { - "name": "John Doe", - "email": "john@example.com", - "password": "secret123", - "credit_card": "1234-5678-9012-3456" - } - }`), - } - - result := engine.ProcessResponse(response) - - if result.Blocked { - log.Printf("Response blocked: %s", result.BlockReason) - return - } - - if len(result.Warnings) > 0 { - log.Printf("Warnings: %v", result.Warnings) - } - - // Use sanitized response - sanitizedResponse := result.SanitizedResponse - log.Printf("Response sanitized successfully") -} -``` - -## Configuration - -### GuardrailConfig - -The main configuration struct that controls all guardrail features: - -```go -type GuardrailConfig struct { - // Data Sanitization - EnableDataSanitization bool `json:"enable_data_sanitization"` - SensitiveFields []string `json:"sensitive_fields"` - RedactionPatterns []string `json:"redaction_patterns"` - - // Content Filtering - EnableContentFiltering bool `json:"enable_content_filtering"` - BlockedKeywords []string `json:"blocked_keywords"` - AllowedDomains []string `json:"allowed_domains"` - - // Rate Limiting - EnableRateLimiting bool `json:"enable_rate_limiting"` - RateLimitConfig RateLimitConfig `json:"rate_limit_config"` - - // Input Validation - EnableInputValidation bool `json:"enable_input_validation"` - ValidationRules map[string]string `json:"validation_rules"` - - // Output Filtering - EnableOutputFiltering bool `json:"enable_output_filtering"` - OutputFilters []string `json:"output_filters"` - - // Logging - EnableLogging bool 
`json:"enable_logging"` - LogLevel string `json:"log_level"` -} -``` - -### Rate Limit Configuration - -```go -type RateLimitConfig struct { - RequestsPerMinute int `json:"requests_per_minute"` - BurstSize int `json:"burst_size"` - WindowSize time.Duration `json:"window_size"` -} -``` - -## Usage Examples - -### Data Sanitization - -```go -config := &guardrails.GuardrailConfig{ - EnableDataSanitization: true, - SensitiveFields: []string{"password", "api_key", "secret"}, -} - -engine := guardrails.NewGuardrailEngine(config) - -// Add custom sensitive pattern -customPattern := guardrails.SensitiveDataPattern{ - Name: "custom_id", - Pattern: `\b[A-Z]{2}\d{6}\b`, - Replacement: "***REDACTED_ID***", - Description: "Custom ID format", -} -engine.AddSensitivePattern(customPattern) -``` - -### Content Filtering - -```go -config := &guardrails.GuardrailConfig{ - EnableContentFiltering: true, -} - -engine := guardrails.NewGuardrailEngine(config) - -// Add custom content filter -filter := guardrails.ContentFilter{ - Type: "keyword", - Pattern: "internal", - Action: "warn", - Description: "Internal information detected", -} -engine.AddContentFilter(filter) -``` - -### Rate Limiting - -```go -config := &guardrails.GuardrailConfig{ - EnableRateLimiting: true, - RateLimitConfig: guardrails.RateLimitConfig{ - RequestsPerMinute: 60, - BurstSize: 5, - WindowSize: time.Minute, - }, -} - -engine := guardrails.NewGuardrailEngine(config) - -// Process requests (rate limiting is automatic) -request := &guardrails.MCPRequest{ - ID: "test", - Method: "tools/list", -} - -result := engine.ProcessRequest(request) -if result.Blocked { - log.Printf("Request blocked: %s", result.BlockReason) -} -``` - -### Input Validation - -```go -config := &guardrails.GuardrailConfig{ - EnableInputValidation: true, - ValidationRules: map[string]string{ - "method": "required", - "id": "alphanumeric", - }, -} - -engine := guardrails.NewGuardrailEngine(config) -``` - -## Default Patterns - -The library includes 
built-in patterns for common sensitive data: - -- **Credit Card Numbers**: `\b\d{4}[- ]?\d{4}[- ]?\d{4}[- ]?\d{4}\b` -- **Social Security Numbers**: `\b\d{3}-\d{2}-\d{4}\b` -- **Email Addresses**: `\b[A-Za-z0-9._%+-]+@[A-Za-z0-9.-]+\.[A-Z|a-z]{2,}\b` -- **Phone Numbers**: `\b\d{3}[-.]?\d{3}[-.]?\d{4}\b` -- **API Keys**: `(api[_-]?key|access[_-]?token|secret[_-]?key)\s*[:=]\s*["']?[A-Za-z0-9]{20,}["']?` -- **Passwords**: `(password|passwd|pwd)\s*[:=]\s*["']?[^"'\s]+["']?` -- **IP Addresses**: `\b(?:[0-9]{1,3}\.){3}[0-9]{1,3}\b` -- **AWS Access Keys**: `AKIA[0-9A-Z]{16}` -- **Private Keys**: `-----BEGIN\s+(?:RSA\s+)?PRIVATE\s+KEY-----` - -## Default Content Filters - -Built-in content filters for security threats: - -- **Code Execution**: `(eval|exec|system|shell_exec)` -- **SQL Injection**: `(union\s+select|drop\s+table|delete\s+from)` -- **XSS**: `]*>.*?` -- **Sensitive Keywords**: `password`, `secret`, `token`, `admin`, `root` - -## API Reference - -### Main Functions - -- `NewGuardrailEngine(config *GuardrailConfig) *GuardrailEngine` -- `ProcessResponse(response *MCPResponse) *GuardrailResult` -- `ProcessRequest(request *MCPRequest) *GuardrailResult` -- `AddSensitivePattern(pattern SensitiveDataPattern)` -- `AddContentFilter(filter ContentFilter)` -- `UpdateConfig(config *GuardrailConfig)` - -### Data Structures - -- `MCPRequest`: MCP request structure -- `MCPResponse`: MCP response structure -- `GuardrailResult`: Result of guardrail processing -- `SensitiveDataPattern`: Pattern for sensitive data detection -- `ContentFilter`: Content filtering rule -- `LogEntry`: Log entry structure - -## Testing - -Run the test suite: - -```bash -go test ./... -``` - -Run tests with coverage: - -```bash -go test -cover ./... -``` - -## Integration with MCP Proxy - -This library is designed to be integrated into MCP proxy servers. 
Here's a typical integration pattern: - -```go -// In your MCP proxy -func handleMCPResponse(originalResponse *MCPResponse) *MCPResponse { - result := guardrailEngine.ProcessResponse(originalResponse) - - if result.Blocked { - // Return error response - return &MCPResponse{ - ID: originalResponse.ID, - Error: &MCPError{ - Code: 403, - Message: result.BlockReason, - }, - } - } - - // Return sanitized response - return result.SanitizedResponse -} -``` - -## Contributing - -1. Fork the repository -2. Create a feature branch -3. Make your changes -4. Add tests for new functionality -5. Run the test suite -6. Submit a pull request - -## License - -This project is licensed under the MIT License - see the LICENSE file for details. - -## Security - -This library is designed for security but should be used as part of a comprehensive security strategy. Always: - -- Keep the library updated -- Review and customize patterns for your specific needs -- Monitor logs for security events -- Test thoroughly in your environment -- Consider additional security measures \ No newline at end of file diff --git a/apps/mcp-guardrails/cmd/demo/main.go b/apps/mcp-guardrails/cmd/demo/main.go deleted file mode 100644 index 2a532e6074..0000000000 --- a/apps/mcp-guardrails/cmd/demo/main.go +++ /dev/null @@ -1,228 +0,0 @@ -package main - -import ( - "encoding/json" - "fmt" - "time" - - "github.com/akto/mcp-guardrails" -) - -func main() { - fmt.Println("MCP Guardrails Library Demo") - fmt.Println("===========================") - - // Create a comprehensive guardrail configuration - config := &guardrails.GuardrailConfig{ - EnableDataSanitization: true, - SensitiveFields: []string{"password", "api_key", "secret", "token"}, - RedactionPatterns: []string{}, - - EnableContentFiltering: true, - BlockedKeywords: []string{"malicious", "dangerous", "exploit"}, - AllowedDomains: []string{"example.com", "api.example.com"}, - - EnableRateLimiting: true, - RateLimitConfig: guardrails.RateLimitConfig{ - 
RequestsPerMinute: 100, - BurstSize: 10, - WindowSize: time.Minute, - }, - - EnableInputValidation: true, - ValidationRules: map[string]string{ - "method": "required", - }, - - EnableOutputFiltering: true, - OutputFilters: []string{"block_sensitive"}, - - EnableLogging: true, - LogLevel: "INFO", - } - - // Create the guardrail engine - engine := guardrails.NewGuardrailEngine(config) - - // Demo 1: Data Sanitization - fmt.Println("\n1. Data Sanitization Demo") - fmt.Println("-------------------------") - - responseWithSensitiveData := &guardrails.MCPResponse{ - ID: "demo_1", - Result: json.RawMessage(`{ - "user": { - "name": "John Doe", - "email": "john.doe@example.com", - "password": "super_secret_password_123", - "credit_card": "1234-5678-9012-3456", - "ssn": "123-45-6789", - "api_key": "sk-1234567890abcdef1234567890abcdef" - }, - "account": { - "id": "ACC123456", - "balance": 1000.50, - "secret_token": "internal_secret_123" - } - }`), - } - - result := engine.ProcessResponse(responseWithSensitiveData) - - if result.Blocked { - fmt.Printf("❌ Response blocked: %s\n", result.BlockReason) - } else { - fmt.Println("βœ… Response processed successfully") - - if len(result.Warnings) > 0 { - fmt.Printf("⚠️ Warnings: %v\n", result.Warnings) - } - - if result.SanitizedResponse != nil { - fmt.Println("πŸ”’ Data sanitization applied") - // Show sanitized result - sanitizedBytes, _ := json.MarshalIndent(result.SanitizedResponse.Result, "", " ") - fmt.Printf("Sanitized result:\n%s\n", string(sanitizedBytes)) - } - } - - // Demo 2: Content Filtering - fmt.Println("\n2. 
Content Filtering Demo") - fmt.Println("-------------------------") - - // Add custom content filter - customFilter := guardrails.ContentFilter{ - Type: "keyword", - Pattern: "internal", - Action: "warn", - Description: "Internal information detected", - } - engine.AddContentFilter(customFilter) - - responseWithInternalInfo := &guardrails.MCPResponse{ - ID: "demo_2", - Result: json.RawMessage(`{ - "message": "This contains internal company information", - "status": "success" - }`), - } - - result = engine.ProcessResponse(responseWithInternalInfo) - - if result.Blocked { - fmt.Printf("❌ Response blocked: %s\n", result.BlockReason) - } else { - fmt.Println("βœ… Response processed successfully") - if len(result.Warnings) > 0 { - fmt.Printf("⚠️ Warnings: %v\n", result.Warnings) - } - } - - // Demo 3: Rate Limiting - fmt.Println("\n3. Rate Limiting Demo") - fmt.Println("---------------------") - - request := &guardrails.MCPRequest{ - ID: "demo_3", - Method: "tools/list", - Params: json.RawMessage(`{"include_hidden": true}`), - } - - // Simulate multiple requests - for i := 1; i <= 5; i++ { - result := engine.ProcessRequest(request) - if result.Blocked { - fmt.Printf("❌ Request %d blocked: %s\n", i, result.BlockReason) - } else { - fmt.Printf("βœ… Request %d allowed\n", i) - } - } - - // Demo 4: Input Validation - fmt.Println("\n4. 
Input Validation Demo") - fmt.Println("-------------------------") - - // Valid request - validRequest := &guardrails.MCPRequest{ - ID: "demo_4_valid", - Method: "tools/call", - Params: json.RawMessage(`{"tool": "file_read"}`), - } - - result = engine.ProcessRequest(validRequest) - if result.Blocked { - fmt.Printf("❌ Valid request blocked: %s\n", result.BlockReason) - } else { - fmt.Println("βœ… Valid request accepted") - } - - // Invalid request (empty method) - invalidRequest := &guardrails.MCPRequest{ - ID: "demo_4_invalid", - Method: "", - Params: json.RawMessage(`{"tool": "file_read"}`), - } - - result = engine.ProcessRequest(invalidRequest) - if result.Blocked { - fmt.Printf("❌ Invalid request correctly blocked: %s\n", result.BlockReason) - } else { - fmt.Println("⚠️ Invalid request should have been blocked") - } - - // Demo 5: Custom Patterns - fmt.Println("\n5. Custom Patterns Demo") - fmt.Println("-----------------------") - - // Add custom sensitive pattern - customPattern := guardrails.SensitiveDataPattern{ - Name: "employee_id", - Pattern: `\bEMP\d{6}\b`, - Replacement: "***REDACTED_EMPLOYEE_ID***", - Description: "Employee ID format", - } - engine.AddSensitivePattern(customPattern) - - responseWithEmployeeData := &guardrails.MCPResponse{ - ID: "demo_5", - Result: json.RawMessage(`{ - "employee": { - "name": "Jane Smith", - "id": "EMP123456", - "department": "Engineering" - } - }`), - } - - result = engine.ProcessResponse(responseWithEmployeeData) - - if result.Blocked { - fmt.Printf("❌ Response blocked: %s\n", result.BlockReason) - } else { - fmt.Println("βœ… Response processed successfully") - - if result.SanitizedResponse != nil { - fmt.Println("πŸ”’ Custom pattern applied") - sanitizedBytes, _ := json.MarshalIndent(result.SanitizedResponse.Result, "", " ") - fmt.Printf("Sanitized result:\n%s\n", string(sanitizedBytes)) - } - } - - // Demo 6: Logging - fmt.Println("\n6. 
Logging Demo") - fmt.Println("---------------") - - if len(result.Logs) > 0 { - fmt.Println("πŸ“ Log entries generated:") - for i, logEntry := range result.Logs { - fmt.Printf(" %d. [%s] %s - %s\n", - i+1, - logEntry.Level, - logEntry.Timestamp.Format("15:04:05"), - logEntry.Message) - } - } - - fmt.Println("\nπŸŽ‰ MCP Guardrails Demo Complete!") - fmt.Println("The library is ready to be integrated into your MCP proxy server.") -} \ No newline at end of file diff --git a/apps/mcp-guardrails/examples.go b/apps/mcp-guardrails/examples.go deleted file mode 100644 index a997aedb03..0000000000 --- a/apps/mcp-guardrails/examples.go +++ /dev/null @@ -1,175 +0,0 @@ -// Package guardrails provides MCP guardrail functionality -package guardrails - -import ( - "encoding/json" - "fmt" - "log" - "time" -) - -// Example usage of the MCP Guardrails library -func ExampleUsage() { - // Create a guardrail configuration - config := &GuardrailConfig{ - EnableDataSanitization: true, - SensitiveFields: []string{"password", "api_key", "secret"}, - RedactionPatterns: []string{}, - - EnableContentFiltering: true, - BlockedKeywords: []string{"malicious", "dangerous"}, - AllowedDomains: []string{"example.com", "api.example.com"}, - - EnableRateLimiting: true, - RateLimitConfig: RateLimitConfig{ - RequestsPerMinute: 100, - BurstSize: 10, - WindowSize: time.Minute, - }, - - EnableInputValidation: true, - ValidationRules: map[string]string{ - "method": "required", - }, - - EnableOutputFiltering: true, - OutputFilters: []string{"block_sensitive"}, - - EnableLogging: true, - LogLevel: "INFO", - } - - // Create the guardrail engine - engine := NewGuardrailEngine(config) - - // Example 1: Process a response with sensitive data - exampleResponse := &MCPResponse{ - ID: "1", - Result: json.RawMessage(`{ - "user": { - "name": "John Doe", - "email": "john.doe@example.com", - "password": "secret123", - "credit_card": "1234-5678-9012-3456" - } - }`), - } - - result := 
engine.ProcessResponse(exampleResponse) - fmt.Printf("Response processed: Blocked=%v, Warnings=%v\n", - result.Blocked, result.Warnings) - - // Example 2: Process a request with validation - exampleRequest := &MCPRequest{ - ID: "2", - Method: "tools/list", - Params: json.RawMessage(`{"include_hidden": true}`), - } - - requestResult := engine.ProcessRequest(exampleRequest) - fmt.Printf("Request processed: Blocked=%v, Warnings=%v\n", - requestResult.Blocked, requestResult.Warnings) - - // Example 3: Add custom sensitive pattern - customPattern := SensitiveDataPattern{ - Name: "custom_id", - Pattern: `\b[A-Z]{2}\d{6}\b`, - Replacement: "***REDACTED_ID***", - Description: "Custom ID format", - } - engine.AddSensitivePattern(customPattern) - - // Example 4: Add custom content filter - customFilter := ContentFilter{ - Type: "keyword", - Pattern: "internal", - Action: "warn", - Description: "Internal information detected", - } - engine.AddContentFilter(customFilter) -} - -// ExampleWithRealData demonstrates processing real MCP data -func ExampleWithRealData() { - config := &GuardrailConfig{ - EnableDataSanitization: true, - EnableContentFiltering: true, - EnableRateLimiting: true, - EnableLogging: true, - LogLevel: "DEBUG", - RateLimitConfig: RateLimitConfig{ - RequestsPerMinute: 60, - BurstSize: 5, - WindowSize: time.Minute, - }, - } - - engine := NewGuardrailEngine(config) - - // Simulate MCP response with sensitive data - responseData := map[string]interface{}{ - "tools": []map[string]interface{}{ - { - "name": "file_read", - "description": "Read file contents", - "parameters": map[string]interface{}{ - "path": "/etc/passwd", - }, - }, - }, - "user_info": map[string]interface{}{ - "id": "US123456", - "email": "admin@company.com", - "password": "super_secret_password", - "api_key": "sk-1234567890abcdef", - "ssn": "123-45-6789", - }, - } - - responseBytes, _ := json.Marshal(responseData) - response := &MCPResponse{ - ID: "req_123", - Result: responseBytes, - } - - // 
Process the response - result := engine.ProcessResponse(response) - - log.Printf("Processing complete:") - log.Printf(" Blocked: %v", result.Blocked) - log.Printf(" Block reason: %s", result.BlockReason) - log.Printf(" Warnings: %v", result.Warnings) - - if result.SanitizedResponse != nil { - log.Printf(" Response sanitized successfully") - } -} - -// ExampleRateLimiting demonstrates rate limiting functionality -func ExampleRateLimiting() { - config := &GuardrailConfig{ - EnableRateLimiting: true, - RateLimitConfig: RateLimitConfig{ - RequestsPerMinute: 10, // Very low for testing - BurstSize: 2, - WindowSize: time.Minute, - }, - } - - engine := NewGuardrailEngine(config) - request := &MCPRequest{ - ID: "test", - Method: "test/method", - } - - // Simulate multiple requests - for i := 0; i < 5; i++ { - result := engine.ProcessRequest(request) - fmt.Printf("Request %d: Allowed=%v, Blocked=%v\n", - i+1, !result.Blocked, result.Blocked) - - if result.Blocked { - fmt.Printf(" Reason: %s\n", result.BlockReason) - } - } -} \ No newline at end of file diff --git a/apps/mcp-guardrails/go.mod b/apps/mcp-guardrails/go.mod deleted file mode 100644 index 853d08efe4..0000000000 --- a/apps/mcp-guardrails/go.mod +++ /dev/null @@ -1,5 +0,0 @@ -module github.com/akto/mcp-guardrails - -go 1.21 - -require github.com/google/uuid v1.4.0 diff --git a/apps/mcp-guardrails/go.sum b/apps/mcp-guardrails/go.sum deleted file mode 100644 index fef9ecd232..0000000000 --- a/apps/mcp-guardrails/go.sum +++ /dev/null @@ -1,2 +0,0 @@ -github.com/google/uuid v1.4.0 h1:MtMxsa51/r9yyhkyLsVeVt0B+BGQZzpQiTQ4eHZ8bc4= -github.com/google/uuid v1.4.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= diff --git a/apps/mcp-guardrails/guardrails.go b/apps/mcp-guardrails/guardrails.go deleted file mode 100644 index 2aaf5e94e1..0000000000 --- a/apps/mcp-guardrails/guardrails.go +++ /dev/null @@ -1,362 +0,0 @@ -package guardrails - -import ( - "encoding/json" - "fmt" - "log" - "regexp" - "strings" - "sync" - 
"time" - "unicode" -) - -// GuardrailEngine is the main engine that applies all guardrails -type GuardrailEngine struct { - config *GuardrailConfig - patterns []SensitiveDataPattern - filters []ContentFilter - rateLimiter *RateLimiter - mutex sync.RWMutex -} - -// NewGuardrailEngine creates a new guardrail engine with the given configuration -func NewGuardrailEngine(config *GuardrailConfig) *GuardrailEngine { - engine := &GuardrailEngine{ - config: config, - patterns: getDefaultSensitivePatterns(), - filters: getDefaultContentFilters(), - } - - if config.EnableRateLimiting { - engine.rateLimiter = NewRateLimiter(config.RateLimitConfig) - } - - return engine -} - -// ProcessResponse applies all configured guardrails to an MCP response -func (g *GuardrailEngine) ProcessResponse(response *MCPResponse) *GuardrailResult { - result := &GuardrailResult{ - SanitizedResponse: response, - Blocked: false, - Warnings: []string{}, - Logs: []LogEntry{}, - } - - g.mutex.RLock() - defer g.mutex.RUnlock() - - // Apply data sanitization - if g.config.EnableDataSanitization { - g.applyDataSanitization(result) - } - - // Apply content filtering - if g.config.EnableContentFiltering { - g.applyContentFiltering(result) - } - - // Apply output filtering - if g.config.EnableOutputFiltering { - g.applyOutputFiltering(result) - } - - // Log the operation if enabled - if g.config.EnableLogging { - g.logOperation(result) - } - - return result -} - -// ProcessRequest applies guardrails to an MCP request -func (g *GuardrailEngine) ProcessRequest(request *MCPRequest) *GuardrailResult { - result := &GuardrailResult{ - Blocked: false, - Warnings: []string{}, - Logs: []LogEntry{}, - } - - g.mutex.RLock() - defer g.mutex.RUnlock() - - // Apply rate limiting - if g.config.EnableRateLimiting && g.rateLimiter != nil { - if !g.rateLimiter.Allow() { - result.Blocked = true - result.BlockReason = "Rate limit exceeded" - return result - } - } - - // Apply input validation - if g.config.EnableInputValidation 
{ - g.applyInputValidation(request, result) - } - - // Apply content filtering to request - if g.config.EnableContentFiltering { - g.applyRequestContentFiltering(request, result) - } - - // Log the operation if enabled - if g.config.EnableLogging { - g.logOperation(result) - } - - return result -} - -// applyDataSanitization sanitizes sensitive data in the response -func (g *GuardrailEngine) applyDataSanitization(result *GuardrailResult) { - if result.SanitizedResponse == nil { - return - } - - // Convert response to string for processing - responseBytes, err := json.Marshal(result.SanitizedResponse) - if err != nil { - result.Warnings = append(result.Warnings, "Failed to marshal response for sanitization") - return - } - - responseStr := string(responseBytes) - - // Apply sensitive data patterns - for _, pattern := range g.patterns { - re, err := regexp.Compile(pattern.Pattern) - if err != nil { - result.Warnings = append(result.Warnings, fmt.Sprintf("Invalid pattern %s: %v", pattern.Name, err)) - continue - } - - responseStr = re.ReplaceAllString(responseStr, pattern.Replacement) - } - - // Apply custom sensitive fields - for _, field := range g.config.SensitiveFields { - fieldPattern := fmt.Sprintf(`"%s"\s*:\s*"([^"]*)"`, regexp.QuoteMeta(field)) - re, err := regexp.Compile(fieldPattern) - if err != nil { - continue - } - responseStr = re.ReplaceAllString(responseStr, fmt.Sprintf(`"%s": "***REDACTED***"`, field)) - } - - // Unmarshal back to response - var sanitizedResponse MCPResponse - if err := json.Unmarshal([]byte(responseStr), &sanitizedResponse); err != nil { - result.Warnings = append(result.Warnings, "Failed to unmarshal sanitized response") - return - } - - result.SanitizedResponse = &sanitizedResponse - result.Logs = append(result.Logs, LogEntry{ - Timestamp: time.Now(), - Level: "INFO", - Message: "Data sanitization applied", - }) -} - -// applyContentFiltering filters content based on configured rules -func (g *GuardrailEngine) 
applyContentFiltering(result *GuardrailResult) { - if result.SanitizedResponse == nil { - return - } - - responseBytes, err := json.Marshal(result.SanitizedResponse) - if err != nil { - return - } - - responseStr := string(responseBytes) - - for _, filter := range g.filters { - var matched bool - switch filter.Type { - case "keyword": - matched = strings.Contains(strings.ToLower(responseStr), strings.ToLower(filter.Pattern)) - case "regex": - re, err := regexp.Compile(filter.Pattern) - if err != nil { - continue - } - matched = re.MatchString(responseStr) - } - - if matched { - switch filter.Action { - case "block": - result.Blocked = true - result.BlockReason = fmt.Sprintf("Content blocked: %s", filter.Description) - return - case "warn": - result.Warnings = append(result.Warnings, fmt.Sprintf("Content warning: %s", filter.Description)) - case "sanitize": - // Apply sanitization based on filter - responseStr = g.sanitizeContent(responseStr, filter) - } - } - } - - // Update response if content was sanitized - if responseStr != string(responseBytes) { - var sanitizedResponse MCPResponse - if err := json.Unmarshal([]byte(responseStr), &sanitizedResponse); err == nil { - result.SanitizedResponse = &sanitizedResponse - } - } -} - -// applyRequestContentFiltering filters content in requests -func (g *GuardrailEngine) applyRequestContentFiltering(request *MCPRequest, result *GuardrailResult) { - requestBytes, err := json.Marshal(request) - if err != nil { - return - } - - requestStr := string(requestBytes) - - for _, filter := range g.filters { - var matched bool - switch filter.Type { - case "keyword": - matched = strings.Contains(strings.ToLower(requestStr), strings.ToLower(filter.Pattern)) - case "regex": - re, err := regexp.Compile(filter.Pattern) - if err != nil { - continue - } - matched = re.MatchString(requestStr) - } - - if matched { - switch filter.Action { - case "block": - result.Blocked = true - result.BlockReason = fmt.Sprintf("Request blocked: %s", 
filter.Description) - return - case "warn": - result.Warnings = append(result.Warnings, fmt.Sprintf("Request warning: %s", filter.Description)) - } - } - } -} - -// applyInputValidation validates input parameters -func (g *GuardrailEngine) applyInputValidation(request *MCPRequest, result *GuardrailResult) { - // Validate method - if request.Method == "" { - result.Blocked = true - result.BlockReason = "Empty method not allowed" - return - } - - // Apply custom validation rules - for field, rule := range g.config.ValidationRules { - if err := g.validateField(request, field, rule); err != nil { - result.Blocked = true - result.BlockReason = fmt.Sprintf("Validation failed for %s: %v", field, err) - return - } - } -} - -// applyOutputFiltering filters output based on configured rules -func (g *GuardrailEngine) applyOutputFiltering(result *GuardrailResult) { - if result.SanitizedResponse == nil { - return - } - - for _, filter := range g.config.OutputFilters { - // Apply output filters (simplified implementation) - if strings.Contains(strings.ToLower(filter), "block") { - result.Warnings = append(result.Warnings, fmt.Sprintf("Output filter applied: %s", filter)) - } - } -} - -// sanitizeContent applies content sanitization based on filter rules -func (g *GuardrailEngine) sanitizeContent(content string, filter ContentFilter) string { - switch filter.Type { - case "keyword": - return strings.ReplaceAll(content, filter.Pattern, "***SANITIZED***") - case "regex": - re, err := regexp.Compile(filter.Pattern) - if err != nil { - return content - } - return re.ReplaceAllString(content, "***SANITIZED***") - default: - return content - } -} - -// validateField validates a specific field based on the given rule -func (g *GuardrailEngine) validateField(request *MCPRequest, field, rule string) error { - // Simplified validation - in a real implementation, you'd have more sophisticated validation - switch rule { - case "required": - if request.Method == "" { - return 
fmt.Errorf("field %s is required", field) - } - case "alphanumeric": - if !isAlphanumeric(request.Method) { - return fmt.Errorf("field %s must be alphanumeric", field) - } - } - return nil -} - -// isAlphanumeric checks if a string contains only alphanumeric characters -func isAlphanumeric(s string) bool { - for _, r := range s { - if !unicode.IsLetter(r) && !unicode.IsDigit(r) { - return false - } - } - return true -} - -// logOperation logs the guardrail operation -func (g *GuardrailEngine) logOperation(result *GuardrailResult) { - logEntry := LogEntry{ - Timestamp: time.Now(), - Level: g.config.LogLevel, - Message: "Guardrail operation completed", - Data: map[string]interface{}{ - "blocked": result.Blocked, - "warnings": result.Warnings, - "block_reason": result.BlockReason, - }, - } - - result.Logs = append(result.Logs, logEntry) - - // Also log to standard logger if configured - if g.config.LogLevel == "DEBUG" { - log.Printf("Guardrail: %+v", logEntry) - } -} - -// AddSensitivePattern adds a new sensitive data pattern -func (g *GuardrailEngine) AddSensitivePattern(pattern SensitiveDataPattern) { - g.mutex.Lock() - defer g.mutex.Unlock() - g.patterns = append(g.patterns, pattern) -} - -// AddContentFilter adds a new content filter -func (g *GuardrailEngine) AddContentFilter(filter ContentFilter) { - g.mutex.Lock() - defer g.mutex.Unlock() - g.filters = append(g.filters, filter) -} - -// UpdateConfig updates the guardrail configuration -func (g *GuardrailEngine) UpdateConfig(config *GuardrailConfig) { - g.mutex.Lock() - defer g.mutex.Unlock() - g.config = config -} \ No newline at end of file diff --git a/apps/mcp-guardrails/guardrails_test.go b/apps/mcp-guardrails/guardrails_test.go deleted file mode 100644 index 9e93b2866d..0000000000 --- a/apps/mcp-guardrails/guardrails_test.go +++ /dev/null @@ -1,345 +0,0 @@ -package guardrails - -import ( - "encoding/json" - "testing" - "time" -) - -func TestNewGuardrailEngine(t *testing.T) { - config := &GuardrailConfig{ - 
EnableDataSanitization: true, - EnableContentFiltering: true, - EnableRateLimiting: true, - EnableLogging: true, - } - - engine := NewGuardrailEngine(config) - if engine == nil { - t.Fatal("Expected engine to be created") - } - - if engine.config != config { - t.Error("Expected config to be set") - } - - if len(engine.patterns) == 0 { - t.Error("Expected default patterns to be loaded") - } - - if len(engine.filters) == 0 { - t.Error("Expected default filters to be loaded") - } - - if engine.rateLimiter == nil { - t.Error("Expected rate limiter to be created when enabled") - } -} - -func TestDataSanitization(t *testing.T) { - config := &GuardrailConfig{ - EnableDataSanitization: true, - SensitiveFields: []string{"password", "secret"}, - } - - engine := NewGuardrailEngine(config) - - // Test response with sensitive data - responseData := map[string]interface{}{ - "user": map[string]interface{}{ - "name": "John Doe", - "email": "john@example.com", - "password": "secret123", - "ssn": "123-45-6789", - }, - "credentials": map[string]interface{}{ - "password": "secret123", - }, - } - - responseBytes, _ := json.Marshal(responseData) - response := &MCPResponse{ - ID: "test", - Result: responseBytes, - } - - result := engine.ProcessResponse(response) - - if result.Blocked { - t.Error("Expected response to not be blocked") - } - - if result.SanitizedResponse == nil { - t.Fatal("Expected sanitized response") - } - - // Check if sensitive data was redacted - sanitizedBytes, _ := json.Marshal(result.SanitizedResponse) - sanitizedStr := string(sanitizedBytes) - - if !contains(sanitizedStr, "***REDACTED***") { - t.Error("Expected password to be redacted") - } - - if !contains(sanitizedStr, "***REDACTED_SSN***") { - t.Error("Expected SSN to be redacted") - } - - if contains(sanitizedStr, "secret123") { - t.Error("Expected original password to be removed") - } -} - -func TestContentFiltering(t *testing.T) { - config := &GuardrailConfig{ - EnableContentFiltering: true, - } - - engine 
:= NewGuardrailEngine(config) - - // Add a blocking filter - blockingFilter := ContentFilter{ - Type: "keyword", - Pattern: "malicious", - Action: "block", - Description: "Malicious content detected", - } - engine.AddContentFilter(blockingFilter) - - // Test response with blocked content - responseData := map[string]interface{}{ - "message": "This is malicious content", - } - - responseBytes, _ := json.Marshal(responseData) - response := &MCPResponse{ - ID: "test", - Result: responseBytes, - } - - result := engine.ProcessResponse(response) - - if !result.Blocked { - t.Error("Expected response to be blocked") - } - - if result.BlockReason == "" { - t.Error("Expected block reason to be set") - } - - // Test response with warning content - warningFilter := ContentFilter{ - Type: "keyword", - Pattern: "warning", - Action: "warn", - Description: "Warning content detected", - } - engine.AddContentFilter(warningFilter) - - responseData = map[string]interface{}{ - "message": "This is a warning message", - } - - responseBytes, _ = json.Marshal(responseData) - response = &MCPResponse{ - ID: "test", - Result: responseBytes, - } - - result = engine.ProcessResponse(response) - - if result.Blocked { - t.Error("Expected response to not be blocked") - } - - if len(result.Warnings) == 0 { - t.Error("Expected warnings to be generated") - } -} - -func TestRateLimiting(t *testing.T) { - config := &GuardrailConfig{ - EnableRateLimiting: true, - RateLimitConfig: RateLimitConfig{ - RequestsPerMinute: 2, - BurstSize: 1, - WindowSize: time.Minute, - }, - } - - engine := NewGuardrailEngine(config) - - request := &MCPRequest{ - ID: "test", - Method: "test/method", - } - - // First request should be allowed - result1 := engine.ProcessRequest(request) - if result1.Blocked { - t.Error("Expected first request to be allowed") - } - - // Second request should be blocked (burst limit) - result2 := engine.ProcessRequest(request) - if !result2.Blocked { - t.Error("Expected second request to be 
blocked") - } - - if result2.BlockReason != "Rate limit exceeded" { - t.Errorf("Expected rate limit reason, got: %s", result2.BlockReason) - } -} - -func TestInputValidation(t *testing.T) { - config := &GuardrailConfig{ - EnableInputValidation: true, - ValidationRules: map[string]string{ - "method": "required", - }, - } - - engine := NewGuardrailEngine(config) - - // Test valid request - validRequest := &MCPRequest{ - ID: "test", - Method: "tools/list", - } - - result := engine.ProcessRequest(validRequest) - if result.Blocked { - t.Error("Expected valid request to be allowed") - } - - // Test invalid request (empty method) - invalidRequest := &MCPRequest{ - ID: "test", - Method: "", - } - - result = engine.ProcessRequest(invalidRequest) - if !result.Blocked { - t.Error("Expected invalid request to be blocked") - } - - if result.BlockReason == "" { - t.Error("Expected block reason to be set") - } -} - -func TestCustomPatterns(t *testing.T) { - config := &GuardrailConfig{ - EnableDataSanitization: true, - } - - engine := NewGuardrailEngine(config) - - // Add custom pattern - customPattern := SensitiveDataPattern{ - Name: "custom_id", - Pattern: `\b[A-Z]{2}\d{6}\b`, - Replacement: "***REDACTED_ID***", - Description: "Custom ID format", - } - engine.AddSensitivePattern(customPattern) - - // Test response with custom pattern - responseData := map[string]interface{}{ - "user_id": "AB123456", - "name": "John Doe", - } - - responseBytes, _ := json.Marshal(responseData) - response := &MCPResponse{ - ID: "test", - Result: responseBytes, - } - - result := engine.ProcessResponse(response) - - if result.SanitizedResponse == nil { - t.Fatal("Expected sanitized response") - } - - sanitizedBytes, _ := json.Marshal(result.SanitizedResponse.Result) - sanitizedStr := string(sanitizedBytes) - - if !contains(sanitizedStr, "***REDACTED_ID***") { - t.Error("Expected custom ID to be redacted") - } - - if contains(sanitizedStr, "AB123456") { - t.Error("Expected original ID to be removed") 
- } -} - -func TestLogging(t *testing.T) { - config := &GuardrailConfig{ - EnableLogging: true, - LogLevel: "INFO", - } - - engine := NewGuardrailEngine(config) - - response := &MCPResponse{ - ID: "test", - Result: json.RawMessage(`{"message": "test"}`), - } - - result := engine.ProcessResponse(response) - - if len(result.Logs) == 0 { - t.Error("Expected logs to be generated") - } - - // Check if log entry has required fields - logEntry := result.Logs[0] - if logEntry.Timestamp.IsZero() { - t.Error("Expected timestamp to be set") - } - - if logEntry.Level == "" { - t.Error("Expected log level to be set") - } - - if logEntry.Message == "" { - t.Error("Expected log message to be set") - } -} - -func TestUpdateConfig(t *testing.T) { - config1 := &GuardrailConfig{ - EnableDataSanitization: true, - EnableLogging: true, - } - - engine := NewGuardrailEngine(config1) - - // Update configuration - config2 := &GuardrailConfig{ - EnableDataSanitization: false, - EnableLogging: false, - } - - engine.UpdateConfig(config2) - - if engine.config != config2 { - t.Error("Expected config to be updated") - } -} - -// Helper function to check if string contains substring -func contains(s, substr string) bool { - return len(s) >= len(substr) && (s == substr || len(substr) == 0 || - (len(s) > len(substr) && (s[:len(substr)] == substr || - s[len(s)-len(substr):] == substr || - func() bool { - for i := 1; i <= len(s)-len(substr); i++ { - if s[i:i+len(substr)] == substr { - return true - } - } - return false - }()))) -} \ No newline at end of file diff --git a/apps/mcp-guardrails/patterns.go b/apps/mcp-guardrails/patterns.go deleted file mode 100644 index accbac402b..0000000000 --- a/apps/mcp-guardrails/patterns.go +++ /dev/null @@ -1,115 +0,0 @@ -package guardrails - -// getDefaultSensitivePatterns returns default patterns for detecting sensitive data -func getDefaultSensitivePatterns() []SensitiveDataPattern { - return []SensitiveDataPattern{ - { - Name: "credit_card", - Pattern: 
`\b\d{4}[- ]?\d{4}[- ]?\d{4}[- ]?\d{4}\b`, - Replacement: "***REDACTED_CREDIT_CARD***", - Description: "Credit card numbers", - }, - { - Name: "ssn", - Pattern: `\b\d{3}-\d{2}-\d{4}\b`, - Replacement: "***REDACTED_SSN***", - Description: "Social Security Numbers", - }, - { - Name: "email", - Pattern: `\b[A-Za-z0-9._%+-]+@[A-Za-z0-9.-]+\.[A-Z|a-z]{2,}\b`, - Replacement: "***REDACTED_EMAIL***", - Description: "Email addresses", - }, - { - Name: "phone", - Pattern: `\b\d{3}[-.]?\d{3}[-.]?\d{4}\b`, - Replacement: "***REDACTED_PHONE***", - Description: "Phone numbers", - }, - { - Name: "api_key", - Pattern: `(api[_-]?key|access[_-]?token|secret[_-]?key)\s*[:=]\s*["']?[A-Za-z0-9]{20,}["']?`, - Replacement: "***REDACTED_API_KEY***", - Description: "API keys and tokens", - }, - { - Name: "password", - Pattern: `"password"\s*:\s*"[^"]*"`, - Replacement: `"password": "***REDACTED_PASSWORD***"`, - Description: "Passwords", - }, - { - Name: "ip_address", - Pattern: `\b(?:[0-9]{1,3}\.){3}[0-9]{1,3}\b`, - Replacement: "***REDACTED_IP***", - Description: "IP addresses", - }, - { - Name: "aws_access_key", - Pattern: `AKIA[0-9A-Z]{16}`, - Replacement: "***REDACTED_AWS_KEY***", - Description: "AWS access keys", - }, - { - Name: "private_key", - Pattern: `-----BEGIN\s+(?:RSA\s+)?PRIVATE\s+KEY-----`, - Replacement: "***REDACTED_PRIVATE_KEY***", - Description: "Private keys", - }, - } -} - -// getDefaultContentFilters returns default content filtering rules -func getDefaultContentFilters() []ContentFilter { - return []ContentFilter{ - { - Type: "keyword", - Pattern: "password", - Action: "warn", - Description: "Password field detected", - }, - { - Type: "keyword", - Pattern: "secret", - Action: "warn", - Description: "Secret field detected", - }, - { - Type: "keyword", - Pattern: "token", - Action: "warn", - Description: "Token field detected", - }, - { - Type: "regex", - Pattern: `(eval|exec|system|shell_exec)`, - Action: "block", - Description: "Potentially dangerous code execution", 
- }, - { - Type: "regex", - Pattern: `(union\s+select|drop\s+table|delete\s+from)`, - Action: "block", - Description: "SQL injection attempt", - }, - { - Type: "regex", - Pattern: `]*>.*?`, - Action: "block", - Description: "XSS attempt", - }, - { - Type: "keyword", - Pattern: "admin", - Action: "warn", - Description: "Admin-related content", - }, - { - Type: "keyword", - Pattern: "root", - Action: "warn", - Description: "Root access content", - }, - } -} \ No newline at end of file diff --git a/apps/mcp-guardrails/rate_limiter.go b/apps/mcp-guardrails/rate_limiter.go deleted file mode 100644 index d4f5351678..0000000000 --- a/apps/mcp-guardrails/rate_limiter.go +++ /dev/null @@ -1,71 +0,0 @@ -package guardrails - -import ( - "sync" - "time" -) - -// RateLimiter implements a simple token bucket rate limiter -type RateLimiter struct { - config RateLimitConfig - tokens int - lastRefill time.Time - mutex sync.Mutex -} - -// NewRateLimiter creates a new rate limiter with the given configuration -func NewRateLimiter(config RateLimitConfig) *RateLimiter { - return &RateLimiter{ - config: config, - tokens: config.BurstSize, - lastRefill: time.Now(), - } -} - -// Allow checks if a request is allowed based on rate limiting -func (r *RateLimiter) Allow() bool { - r.mutex.Lock() - defer r.mutex.Unlock() - - now := time.Now() - - // Refill tokens based on time elapsed - timeElapsed := now.Sub(r.lastRefill) - tokensToAdd := int(timeElapsed.Minutes()) * r.config.RequestsPerMinute / 60 - - if tokensToAdd > 0 { - r.tokens = min(r.config.BurstSize, r.tokens+tokensToAdd) - r.lastRefill = now - } - - // Check if we have tokens available - if r.tokens > 0 { - r.tokens-- - return true - } - - return false -} - -// min returns the minimum of two integers -func min(a, b int) int { - if a < b { - return a - } - return b -} - -// GetRemainingTokens returns the number of remaining tokens -func (r *RateLimiter) GetRemainingTokens() int { - r.mutex.Lock() - defer r.mutex.Unlock() - return 
r.tokens -} - -// Reset resets the rate limiter -func (r *RateLimiter) Reset() { - r.mutex.Lock() - defer r.mutex.Unlock() - r.tokens = r.config.BurstSize - r.lastRefill = time.Now() -} \ No newline at end of file diff --git a/apps/mcp-guardrails/types.go b/apps/mcp-guardrails/types.go deleted file mode 100644 index b6e3b7cba4..0000000000 --- a/apps/mcp-guardrails/types.go +++ /dev/null @@ -1,98 +0,0 @@ -package guardrails - -import ( - "encoding/json" - "time" -) - -// MCPRequest represents an MCP server request -type MCPRequest struct { - ID string `json:"id"` - Method string `json:"method"` - Params json.RawMessage `json:"params,omitempty"` - Metadata map[string]any `json:"metadata,omitempty"` -} - -// MCPResponse represents an MCP server response -type MCPResponse struct { - ID string `json:"id"` - Result json.RawMessage `json:"result,omitempty"` - Error *MCPError `json:"error,omitempty"` - Metadata map[string]any `json:"metadata,omitempty"` -} - -// MCPError represents an MCP error response -type MCPError struct { - Code int `json:"code"` - Message string `json:"message"` - Data any `json:"data,omitempty"` -} - -// GuardrailConfig holds configuration for different types of guardrails -type GuardrailConfig struct { - // Data Sanitization - EnableDataSanitization bool `json:"enable_data_sanitization"` - SensitiveFields []string `json:"sensitive_fields"` - RedactionPatterns []string `json:"redaction_patterns"` - - // Content Filtering - EnableContentFiltering bool `json:"enable_content_filtering"` - BlockedKeywords []string `json:"blocked_keywords"` - AllowedDomains []string `json:"allowed_domains"` - - // Rate Limiting - EnableRateLimiting bool `json:"enable_rate_limiting"` - RateLimitConfig RateLimitConfig `json:"rate_limit_config"` - - // Input Validation - EnableInputValidation bool `json:"enable_input_validation"` - ValidationRules map[string]string `json:"validation_rules"` - - // Output Filtering - EnableOutputFiltering bool 
`json:"enable_output_filtering"` - OutputFilters []string `json:"output_filters"` - - // Logging - EnableLogging bool `json:"enable_logging"` - LogLevel string `json:"log_level"` -} - -// RateLimitConfig holds rate limiting configuration -type RateLimitConfig struct { - RequestsPerMinute int `json:"requests_per_minute"` - BurstSize int `json:"burst_size"` - WindowSize time.Duration `json:"window_size"` -} - -// GuardrailResult represents the result of applying guardrails -type GuardrailResult struct { - SanitizedResponse *MCPResponse `json:"sanitized_response,omitempty"` - Blocked bool `json:"blocked"` - BlockReason string `json:"block_reason,omitempty"` - Warnings []string `json:"warnings,omitempty"` - Logs []LogEntry `json:"logs,omitempty"` -} - -// LogEntry represents a log entry for guardrail operations -type LogEntry struct { - Timestamp time.Time `json:"timestamp"` - Level string `json:"level"` - Message string `json:"message"` - Data any `json:"data,omitempty"` -} - -// SensitiveDataPattern represents patterns for detecting sensitive data -type SensitiveDataPattern struct { - Name string `json:"name"` - Pattern string `json:"pattern"` - Replacement string `json:"replacement"` - Description string `json:"description"` -} - -// ContentFilter represents content filtering rules -type ContentFilter struct { - Type string `json:"type"` // "keyword", "regex", "domain" - Pattern string `json:"pattern"` - Action string `json:"action"` // "block", "warn", "sanitize" - Description string `json:"description"` -} \ No newline at end of file diff --git a/go.work b/go.work new file mode 100644 index 0000000000..ddf2fde15f --- /dev/null +++ b/go.work @@ -0,0 +1,3 @@ +go 1.22 + +use ./libs/mcp-guardrails diff --git a/go.work.sum b/go.work.sum new file mode 100644 index 0000000000..1ac1a15a35 --- /dev/null +++ b/go.work.sum @@ -0,0 +1,6 @@ +golang.org/x/mod v0.8.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs= +golang.org/x/term v0.8.0/go.mod 
h1:xPskH00ivmX89bAKVGSKKtLOWNx2+17Eiy94tnKShWo= +golang.org/x/tools v0.6.0/go.mod h1:Xwgl3UAJ/d3gWutnCtw505GrjyAbvKui8lOU390QaIU= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= +gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/libs/dao/src/main/java/com/akto/dao/mcp/MCPGuardrailYamlTemplateDao.java b/libs/dao/src/main/java/com/akto/dao/mcp/MCPGuardrailYamlTemplateDao.java new file mode 100644 index 0000000000..a257752e16 --- /dev/null +++ b/libs/dao/src/main/java/com/akto/dao/mcp/MCPGuardrailYamlTemplateDao.java @@ -0,0 +1,61 @@ +package com.akto.dao.mcp; + +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +import com.akto.dao.AccountsContextDao; +import com.akto.dto.mcp.MCPGuardrailConfig; +import com.akto.dto.mcp.MCPGuardrailConfigYamlParser; +import com.akto.dto.test_editor.YamlTemplate; +import com.mongodb.client.model.Filters; + +public class MCPGuardrailYamlTemplateDao extends AccountsContextDao { + + public static final MCPGuardrailYamlTemplateDao instance = new MCPGuardrailYamlTemplateDao(); + + public Map fetchMCPGuardrailConfig(boolean includeYamlContent) { + List yamlTemplates = MCPGuardrailYamlTemplateDao.instance.findAll(Filters.empty()); + return fetchMCPGuardrailConfig(includeYamlContent, yamlTemplates); + } + + public Map fetchMCPGuardrailConfig(boolean includeYamlContent, List yamlTemplates) { + Map guardrailConfigMap = new HashMap<>(); + for (YamlTemplate yamlTemplate : yamlTemplates) { + try { + if (yamlTemplate != null) { + MCPGuardrailConfig guardrailConfig = MCPGuardrailConfigYamlParser.parseTemplate(yamlTemplate.getContent()); + guardrailConfig.setAuthor(yamlTemplate.getAuthor()); + guardrailConfig.setCreatedAt(yamlTemplate.getCreatedAt()); + guardrailConfig.setUpdatedAt(yamlTemplate.getUpdatedAt()); + if (includeYamlContent) { + 
guardrailConfig.setContent(yamlTemplate.getContent()); + } + guardrailConfigMap.put(guardrailConfig.getId(), guardrailConfig); + } + } catch (Exception e) { + e.printStackTrace(); + } + } + return guardrailConfigMap; + } + + public List fetchActiveTemplates() { + return MCPGuardrailYamlTemplateDao.instance.findAll(Filters.eq(YamlTemplate.INACTIVE, false)); + } + + public List fetchTemplatesByType(String type) { + //todo: shivam add filtering by type + return fetchActiveTemplates(); + } + + @Override + public String getCollName() { + return "mcp_guardrail_yaml_templates"; + } + + @Override + public Class getClassT() { + return YamlTemplate.class; + } +} diff --git a/libs/dao/src/main/java/com/akto/dto/mcp/MCPGuardrailConfig.java b/libs/dao/src/main/java/com/akto/dto/mcp/MCPGuardrailConfig.java new file mode 100644 index 0000000000..cb5193d069 --- /dev/null +++ b/libs/dao/src/main/java/com/akto/dto/mcp/MCPGuardrailConfig.java @@ -0,0 +1,96 @@ +package com.akto.dto.mcp; + +import com.akto.dto.test_editor.ConfigParserResult; +import com.akto.dto.test_editor.ExecutorConfigParserResult; +import com.akto.dto.test_editor.Info; + +public class MCPGuardrailConfig { + private String id; + public static final String ID = "id"; + private ConfigParserResult filter; + public static final String FILTER = "filter"; + public static final String CREATED_AT = "createdAt"; + private int createdAt; + public static final String UPDATED_AT = "updatedAt"; + private int updatedAt; + public static final String _AUTHOR = "author"; + private String author; + public static final String _CONTENT = "content"; + private String content; + public static final String _INFO = "info"; + private Info info; + private ExecutorConfigParserResult executor; + + public MCPGuardrailConfig(String id, ConfigParserResult filter) { + this.id = id; + this.filter = filter; + } + + public MCPGuardrailConfig() { + } + + public String getId() { + return id; + } + + public void setId(String id) { + this.id = id; + } + + 
public ConfigParserResult getFilter() { + return filter; + } + + public void setFilter(ConfigParserResult filter) { + this.filter = filter; + } + + public int getCreatedAt() { + return createdAt; + } + + public void setCreatedAt(int createdAt) { + this.createdAt = createdAt; + } + + public int getUpdatedAt() { + return updatedAt; + } + + public void setUpdatedAt(int updatedAt) { + this.updatedAt = updatedAt; + } + + public String getAuthor() { + return author; + } + + public void setAuthor(String author) { + this.author = author; + } + + public String getContent() { + return content; + } + + public void setContent(String content) { + this.content = content; + } + + public ExecutorConfigParserResult getExecutor() { + return executor; + } + + public void setExecutor(ExecutorConfigParserResult executor) { + this.executor = executor; + } + + + public Info getInfo() { + return info; + } + + public void setInfo(Info info) { + this.info = info; + } +} diff --git a/libs/dao/src/main/java/com/akto/dto/mcp/MCPGuardrailConfigYamlParser.java b/libs/dao/src/main/java/com/akto/dto/mcp/MCPGuardrailConfigYamlParser.java new file mode 100644 index 0000000000..5dd6ded367 --- /dev/null +++ b/libs/dao/src/main/java/com/akto/dto/mcp/MCPGuardrailConfigYamlParser.java @@ -0,0 +1,66 @@ +package com.akto.dto.mcp; + +import java.util.Map; + +import com.akto.dao.test_editor.filter.ConfigParser; +import com.akto.dao.test_editor.info.InfoParser; +import com.akto.dto.test_editor.ConfigParserResult; +import com.akto.dto.test_editor.ExecutorConfigParserResult; +import com.akto.dto.test_editor.Info; +import com.fasterxml.jackson.core.type.TypeReference; +import com.fasterxml.jackson.databind.ObjectMapper; +import com.fasterxml.jackson.dataformat.yaml.YAMLFactory; + +public class MCPGuardrailConfigYamlParser { + + public static MCPGuardrailConfig parseTemplate(String content) throws Exception { + + ObjectMapper mapper = new ObjectMapper(new YAMLFactory()); + + Map config = mapper.readValue(content, 
new TypeReference>() { + }); + return parseConfig(config); + } + + public static MCPGuardrailConfig parseConfig(Map config) throws Exception { + + MCPGuardrailConfig guardrailConfig = null; + + String id = (String) config.get(MCPGuardrailConfig.ID); + if (id == null) { + return guardrailConfig; + } + + Object filterMap = config.get(MCPGuardrailConfig.FILTER); + if (filterMap == null) { + // todo: should not be null, throw error + guardrailConfig = new MCPGuardrailConfig(id, null); + } + + ConfigParser configParser = new ConfigParser(); + ConfigParserResult filters = configParser.parse(filterMap); + if (filters == null) { + // todo: throw error + guardrailConfig = new MCPGuardrailConfig(id, null); + } else { + guardrailConfig = new MCPGuardrailConfig(id, filters); + } + + + com.akto.dao.test_editor.executor.ConfigParser executorConfigParser = new com.akto.dao.test_editor.executor.ConfigParser(); + Object executionMap = config.get("execute"); + if(executionMap != null){ + ExecutorConfigParserResult executorConfigParserResult = executorConfigParser.parseConfigMap(executionMap); + guardrailConfig.setExecutor(executorConfigParserResult); + } + + InfoParser infoParser = new InfoParser(); + if (config.containsKey("info")) { + Info info = infoParser.parse(config.get("info")); + guardrailConfig.setInfo(info); + } + return guardrailConfig; + + } + +} diff --git a/libs/dao/src/main/java/com/akto/dto/mcp/MCPGuardrailType.java b/libs/dao/src/main/java/com/akto/dto/mcp/MCPGuardrailType.java new file mode 100644 index 0000000000..d36a8c49cf --- /dev/null +++ b/libs/dao/src/main/java/com/akto/dto/mcp/MCPGuardrailType.java @@ -0,0 +1,8 @@ +package com.akto.dto.mcp; + +public enum MCPGuardrailType { + DATA_SANITIZATION, + CONTENT_FILTERING, + INPUT_VALIDATION, + OUTPUT_FILTERING +} diff --git a/libs/dao/src/test/java/com/akto/dao/mcp/MCPGuardrailIntegrationTest.java b/libs/dao/src/test/java/com/akto/dao/mcp/MCPGuardrailIntegrationTest.java new file mode 100644 index 
0000000000..1a128cea13 --- /dev/null +++ b/libs/dao/src/test/java/com/akto/dao/mcp/MCPGuardrailIntegrationTest.java @@ -0,0 +1,170 @@ +package com.akto.dao.mcp; + +import com.akto.dto.mcp.MCPGuardrailConfig; +import com.akto.dto.mcp.MCPGuardrailConfigYamlParser; +import com.akto.dto.test_editor.YamlTemplate; +import org.junit.Test; +import static org.junit.Assert.*; + +import java.util.Arrays; +import java.util.List; +import java.util.Map; + +public class MCPGuardrailIntegrationTest { + + @Test + public void testCompleteInsertAndFetchFlow() throws Exception { + // Create a test YAML template + String yamlContent = "id: \"integration-test-001\"\n" + + "filter:\n" + + " pred:\n" + + " - data:\n" + + " - \"integration-test-value\""; + + // Create YamlTemplate object + YamlTemplate yamlTemplate = new YamlTemplate(); + yamlTemplate.setId("template-integration-001"); + yamlTemplate.setAuthor("integration-author"); + yamlTemplate.setCreatedAt(1234567890); + yamlTemplate.setUpdatedAt(1234567890); + yamlTemplate.setContent(yamlContent); + + // Test parsing the YAML content + MCPGuardrailConfig parsedConfig = MCPGuardrailConfigYamlParser.parseTemplate(yamlContent); + + assertNotNull(parsedConfig); + assertEquals("integration-test-001", parsedConfig.getId()); + assertNotNull(parsedConfig.getFilter()); + + // Test DAO functionality + MCPGuardrailYamlTemplateDao dao = new MCPGuardrailYamlTemplateDao(); + List templates = Arrays.asList(yamlTemplate); + + // Test fetch with content + Map resultWithContent = dao.fetchMCPGuardrailConfig(true, templates); + + assertEquals(1, resultWithContent.size()); + MCPGuardrailConfig fetchedConfig = resultWithContent.get("integration-test-001"); + + assertNotNull(fetchedConfig); + assertEquals("integration-test-001", fetchedConfig.getId()); + assertNotNull(fetchedConfig.getFilter()); + assertEquals("integration-author", fetchedConfig.getAuthor()); + assertEquals(1234567890, fetchedConfig.getCreatedAt()); + assertEquals(1234567890, 
fetchedConfig.getUpdatedAt()); + assertNotNull(fetchedConfig.getContent()); + assertTrue(fetchedConfig.getContent().contains("integration-test-001")); + + // Test fetch without content + Map resultWithoutContent = dao.fetchMCPGuardrailConfig(false, templates); + + assertEquals(1, resultWithoutContent.size()); + MCPGuardrailConfig fetchedConfigNoContent = resultWithoutContent.get("integration-test-001"); + + assertNotNull(fetchedConfigNoContent); + assertEquals("integration-test-001", fetchedConfigNoContent.getId()); + assertNull(fetchedConfigNoContent.getContent()); // Content should be null + } + + @Test + public void testMultipleTemplatesFlow() throws Exception { + // Create multiple test templates + YamlTemplate template1 = new YamlTemplate(); + template1.setId("template-1"); + template1.setAuthor("author1"); + template1.setCreatedAt(1234567890); + template1.setUpdatedAt(1234567890); + template1.setContent("id: \"guardrail-001\"\n" + + "filter:\n" + + " pred:\n" + + " - data:\n" + + " - \"value1\""); + + YamlTemplate template2 = new YamlTemplate(); + template2.setId("template-2"); + template2.setAuthor("author2"); + template2.setCreatedAt(1234567891); + template2.setUpdatedAt(1234567891); + template2.setContent("id: \"guardrail-002\"\n" + + "filter:\n" + + " pred:\n" + + " - data:\n" + + " - \"value2\""); + + YamlTemplate template3 = new YamlTemplate(); + template3.setId("template-3"); + template3.setAuthor("author3"); + template3.setCreatedAt(1234567892); + template3.setUpdatedAt(1234567892); + template3.setContent("id: \"guardrail-003\"\n" + + "filter:\n" + + " pred:\n" + + " - data:\n" + + " - \"value3\""); + + List templates = Arrays.asList(template1, template2, template3); + + MCPGuardrailYamlTemplateDao dao = new MCPGuardrailYamlTemplateDao(); + Map result = dao.fetchMCPGuardrailConfig(true, templates); + + assertEquals(3, result.size()); + + // Verify first guardrail + MCPGuardrailConfig config1 = result.get("guardrail-001"); + assertNotNull(config1); + 
assertEquals("guardrail-001", config1.getId()); + assertNotNull(config1.getFilter()); + assertEquals("author1", config1.getAuthor()); + + // Verify second guardrail + MCPGuardrailConfig config2 = result.get("guardrail-002"); + assertNotNull(config2); + assertEquals("guardrail-002", config2.getId()); + assertNotNull(config2.getFilter()); + assertEquals("author2", config2.getAuthor()); + + // Verify third guardrail + MCPGuardrailConfig config3 = result.get("guardrail-003"); + assertNotNull(config3); + assertEquals("guardrail-003", config3.getId()); + assertNotNull(config3.getFilter()); + assertEquals("author3", config3.getAuthor()); + } + + + @Test + public void testErrorHandlingWithInvalidTemplates() { + // Create template with invalid YAML + YamlTemplate invalidTemplate = new YamlTemplate(); + invalidTemplate.setId("invalid-template"); + invalidTemplate.setAuthor("invalid-author"); + invalidTemplate.setCreatedAt(1234567890); + invalidTemplate.setUpdatedAt(1234567890); + invalidTemplate.setContent("invalid: yaml: content: ["); + + // Create template with null content + YamlTemplate nullContentTemplate = new YamlTemplate(); + nullContentTemplate.setId("null-content-template"); + nullContentTemplate.setAuthor("null-author"); + nullContentTemplate.setCreatedAt(1234567891); + nullContentTemplate.setUpdatedAt(1234567891); + nullContentTemplate.setContent(null); + + // Create template with empty content + YamlTemplate emptyContentTemplate = new YamlTemplate(); + emptyContentTemplate.setId("empty-content-template"); + emptyContentTemplate.setAuthor("empty-author"); + emptyContentTemplate.setCreatedAt(1234567892); + emptyContentTemplate.setUpdatedAt(1234567892); + emptyContentTemplate.setContent(""); + + List templates = Arrays.asList(invalidTemplate, nullContentTemplate, emptyContentTemplate); + + MCPGuardrailYamlTemplateDao dao = new MCPGuardrailYamlTemplateDao(); + + // Should not throw exception, but should skip invalid templates + Map result = 
dao.fetchMCPGuardrailConfig(true, templates); + + assertEquals(0, result.size()); // All invalid templates should be skipped + } +} diff --git a/libs/dao/src/test/java/com/akto/dao/mcp/MCPGuardrailTestRunner.java b/libs/dao/src/test/java/com/akto/dao/mcp/MCPGuardrailTestRunner.java new file mode 100644 index 0000000000..a522831458 --- /dev/null +++ b/libs/dao/src/test/java/com/akto/dao/mcp/MCPGuardrailTestRunner.java @@ -0,0 +1,159 @@ +package com.akto.dao.mcp; + +import com.akto.dto.mcp.MCPGuardrailConfig; +import com.akto.dto.mcp.MCPGuardrailConfigYamlParser; +import com.akto.dto.test_editor.YamlTemplate; + +import java.util.Arrays; +import java.util.List; +import java.util.Map; + +/** + * Simple test runner to verify MCPGuardrailConfig functionality + * This can be run to test insert and fetch operations + */ +public class MCPGuardrailTestRunner { + + public static void main(String[] args) { + System.out.println("Starting MCPGuardrailConfig Test Runner..."); + + try { + testBasicYamlParsing(); + testDaoOperations(); + testErrorHandling(); + System.out.println("All tests passed successfully!"); + } catch (Exception e) { + System.err.println("Test failed: " + e.getMessage()); + e.printStackTrace(); + } + } + + public static void testBasicYamlParsing() throws Exception { + System.out.println("Testing basic YAML parsing..."); + + String yamlContent = "id: \"test-guardrail-001\"\n" + + "filter:\n" + + " pred:\n" + + " - data:\n" + + " - \"test-value\""; + + MCPGuardrailConfig config = MCPGuardrailConfigYamlParser.parseTemplate(yamlContent); + + if (config == null) { + throw new RuntimeException("Config should not be null"); + } + if (!"test-guardrail-001".equals(config.getId())) { + throw new RuntimeException("ID mismatch"); + } + if (config.getFilter() == null) { + throw new RuntimeException("Filter should not be null"); + } + + System.out.println("βœ“ Basic YAML parsing test passed"); + } + + public static void testDaoOperations() throws Exception { + 
System.out.println("Testing DAO operations..."); + + // Create test YamlTemplate + YamlTemplate template = new YamlTemplate(); + template.setId("template-1"); + template.setAuthor("test-author"); + template.setCreatedAt(1234567890); + template.setUpdatedAt(1234567890); + template.setContent("id: \"guardrail-001\"\n" + + "filter:\n" + + " pred:\n" + + " - data:\n" + + " - \"test-value\""); + + List templates = Arrays.asList(template); + MCPGuardrailYamlTemplateDao dao = new MCPGuardrailYamlTemplateDao(); + + // Test fetch with content + Map resultWithContent = dao.fetchMCPGuardrailConfig(true, templates); + + if (resultWithContent.size() != 1) { + throw new RuntimeException("Expected 1 config, got " + resultWithContent.size()); + } + + MCPGuardrailConfig config = resultWithContent.get("guardrail-001"); + if (config == null) { + throw new RuntimeException("Config should not be null"); + } + if (!"guardrail-001".equals(config.getId())) { + throw new RuntimeException("ID mismatch"); + } + if (config.getFilter() == null) { + throw new RuntimeException("Filter should not be null"); + } + if (config.getContent() == null) { + throw new RuntimeException("Content should not be null when includeYamlContent=true"); + } + + // Test fetch without content + Map resultWithoutContent = dao.fetchMCPGuardrailConfig(false, templates); + + if (resultWithoutContent.size() != 1) { + throw new RuntimeException("Expected 1 config, got " + resultWithoutContent.size()); + } + + MCPGuardrailConfig configNoContent = resultWithoutContent.get("guardrail-001"); + if (configNoContent == null) { + throw new RuntimeException("Config should not be null"); + } + if (configNoContent.getContent() != null) { + throw new RuntimeException("Content should be null when includeYamlContent=false"); + } + + System.out.println("βœ“ DAO operations test passed"); + } + + public static void testErrorHandling() throws Exception { + System.out.println("Testing error handling..."); + + // Test invalid YAML + try { + 
MCPGuardrailConfigYamlParser.parseTemplate("invalid: yaml: content: ["); + throw new RuntimeException("Should have thrown exception for invalid YAML"); + } catch (Exception e) { + // Expected + } + + // Test empty YAML + try { + MCPGuardrailConfigYamlParser.parseTemplate(""); + throw new RuntimeException("Should have thrown exception for empty YAML"); + } catch (IllegalArgumentException e) { + // Expected + } + + // Test null YAML + try { + MCPGuardrailConfigYamlParser.parseTemplate(null); + throw new RuntimeException("Should have thrown exception for null YAML"); + } catch (IllegalArgumentException e) { + // Expected + } + + // Test invalid template in DAO + YamlTemplate invalidTemplate = new YamlTemplate(); + invalidTemplate.setId("invalid-template"); + invalidTemplate.setAuthor("invalid-author"); + invalidTemplate.setCreatedAt(1234567890); + invalidTemplate.setUpdatedAt(1234567890); + invalidTemplate.setContent("invalid: yaml: content: ["); + + List templates = Arrays.asList(invalidTemplate); + MCPGuardrailYamlTemplateDao dao = new MCPGuardrailYamlTemplateDao(); + + // Should not throw exception, but should skip invalid templates + Map result = dao.fetchMCPGuardrailConfig(true, templates); + + if (result.size() != 0) { + throw new RuntimeException("Expected 0 configs for invalid template, got " + result.size()); + } + + System.out.println("βœ“ Error handling test passed"); + } +} diff --git a/libs/dao/src/test/java/com/akto/dao/mcp/MCPGuardrailYamlTemplateDaoTest.java b/libs/dao/src/test/java/com/akto/dao/mcp/MCPGuardrailYamlTemplateDaoTest.java new file mode 100644 index 0000000000..eb57a76cfd --- /dev/null +++ b/libs/dao/src/test/java/com/akto/dao/mcp/MCPGuardrailYamlTemplateDaoTest.java @@ -0,0 +1,199 @@ +package com.akto.dao.mcp; + +import com.akto.dto.mcp.MCPGuardrailConfig; +import com.akto.dto.mcp.MCPGuardrailConfigYamlParser; +import com.akto.dto.test_editor.YamlTemplate; +import org.junit.Test; +import static org.junit.Assert.*; + +import 
java.util.Arrays; +import java.util.List; +import java.util.Map; + +public class MCPGuardrailYamlTemplateDaoTest { + + @Test + public void testParseValidYamlTemplate() throws Exception { + String yamlContent = "id: \"test-guardrail-001\"\n" + + "filter:\n" + + " pred:\n" + + " - data:\n" + + " - \"test-value\""; + + MCPGuardrailConfig config = MCPGuardrailConfigYamlParser.parseTemplate(yamlContent); + + assertNotNull(config); + assertEquals("test-guardrail-001", config.getId()); + assertNotNull(config.getFilter()); + } + + @Test + public void testParseYamlWithDefaults() throws Exception { + String yamlContent = "id: \"minimal-guardrail\"\n" + + "filter:\n" + + " pred:\n" + + " - data:\n" + + " - \"default-value\""; + + MCPGuardrailConfig config = MCPGuardrailConfigYamlParser.parseTemplate(yamlContent); + + assertNotNull(config); + assertEquals("minimal-guardrail", config.getId()); + assertNotNull(config.getFilter()); + } + + + @Test(expected = Exception.class) + public void testParseInvalidYaml() throws Exception { + String invalidYaml = "invalid: yaml: content: ["; + MCPGuardrailConfigYamlParser.parseTemplate(invalidYaml); + } + + @Test(expected = IllegalArgumentException.class) + public void testParseEmptyYaml() throws Exception { + MCPGuardrailConfigYamlParser.parseTemplate(""); + } + + @Test(expected = IllegalArgumentException.class) + public void testParseNullYaml() throws Exception { + MCPGuardrailConfigYamlParser.parseTemplate(null); + } + + @Test + public void testParseConfigWithMap() throws Exception { + java.util.Map configMap = new java.util.HashMap<>(); + configMap.put("id", "map-test-001"); + + java.util.Map filterMap = new java.util.HashMap<>(); + java.util.List predList = new java.util.ArrayList<>(); + java.util.Map predItem = new java.util.HashMap<>(); + java.util.List dataList = new java.util.ArrayList<>(); + dataList.add("test-data"); + predItem.put("data", dataList); + predList.add(predItem); + filterMap.put("pred", predList); + 
configMap.put("filter", filterMap); + + + MCPGuardrailConfig config = MCPGuardrailConfigYamlParser.parseConfig(configMap); + + assertNotNull(config); + assertEquals("map-test-001", config.getId()); + assertNotNull(config.getFilter()); + } + + @Test + public void testParseConfigWithNullId() throws Exception { + java.util.Map configMap = new java.util.HashMap<>(); + configMap.put("name", "No ID Test"); + + MCPGuardrailConfig config = MCPGuardrailConfigYamlParser.parseConfig(configMap); + + assertNull(config); + } + + @Test + public void testFetchMCPGuardrailConfigWithYamlTemplates() { + // Create test YamlTemplate objects + YamlTemplate template1 = new YamlTemplate(); + template1.setId("template-1"); + template1.setAuthor("author1"); + template1.setCreatedAt(1234567890); + template1.setUpdatedAt(1234567890); + template1.setContent("id: \"guardrail-001\"\n" + + "filter:\n" + + " pred:\n" + + " - data:\n" + + " - \"test-value-1\""); + + YamlTemplate template2 = new YamlTemplate(); + template2.setId("template-2"); + template2.setAuthor("author2"); + template2.setCreatedAt(1234567891); + template2.setUpdatedAt(1234567891); + template2.setContent("id: \"guardrail-002\"\n" + + "filter:\n" + + " pred:\n" + + " - data:\n" + + " - \"test-value-2\""); + + List templates = Arrays.asList(template1, template2); + + MCPGuardrailYamlTemplateDao dao = new MCPGuardrailYamlTemplateDao(); + + // Test fetchMCPGuardrailConfig with includeYamlContent = true + Map resultWithContent = dao.fetchMCPGuardrailConfig(true, templates); + + assertEquals(2, resultWithContent.size()); + + MCPGuardrailConfig config1 = resultWithContent.get("guardrail-001"); + assertNotNull(config1); + assertEquals("guardrail-001", config1.getId()); + assertNotNull(config1.getFilter()); + assertEquals("author1", config1.getAuthor()); + assertEquals(1234567890, config1.getCreatedAt()); + assertEquals(1234567890, config1.getUpdatedAt()); + assertNotNull(config1.getContent()); + 
assertTrue(config1.getContent().contains("guardrail-001")); + + MCPGuardrailConfig config2 = resultWithContent.get("guardrail-002"); + assertNotNull(config2); + assertEquals("guardrail-002", config2.getId()); + assertNotNull(config2.getFilter()); + assertEquals("author2", config2.getAuthor()); + assertEquals(1234567891, config2.getCreatedAt()); + assertEquals(1234567891, config2.getUpdatedAt()); + assertNotNull(config2.getContent()); + assertTrue(config2.getContent().contains("guardrail-002")); + + // Test fetchMCPGuardrailConfig with includeYamlContent = false + Map resultWithoutContent = dao.fetchMCPGuardrailConfig(false, templates); + + assertEquals(2, resultWithoutContent.size()); + + MCPGuardrailConfig config1NoContent = resultWithoutContent.get("guardrail-001"); + assertNotNull(config1NoContent); + assertNull(config1NoContent.getContent()); // Content should be null when includeYamlContent = false + } + + @Test + public void testFetchMCPGuardrailConfigWithInvalidYaml() { + YamlTemplate invalidTemplate = new YamlTemplate(); + invalidTemplate.setId("invalid-template"); + invalidTemplate.setAuthor("invalid-author"); + invalidTemplate.setCreatedAt(1234567890); + invalidTemplate.setUpdatedAt(1234567890); + invalidTemplate.setContent("invalid: yaml: content: ["); + + List templates = Arrays.asList(invalidTemplate); + + MCPGuardrailYamlTemplateDao dao = new MCPGuardrailYamlTemplateDao(); + + // Should not throw exception, but should skip invalid templates + Map result = dao.fetchMCPGuardrailConfig(true, templates); + + assertEquals(0, result.size()); // Invalid template should be skipped + } + + @Test + public void testFetchMCPGuardrailConfigWithNullTemplate() { + List templates = Arrays.asList((YamlTemplate) null); + + MCPGuardrailYamlTemplateDao dao = new MCPGuardrailYamlTemplateDao(); + Map result = dao.fetchMCPGuardrailConfig(true, templates); + + assertEquals(0, result.size()); // Null template should be skipped + } + + @Test + public void testGetCollName() { + 
MCPGuardrailYamlTemplateDao dao = new MCPGuardrailYamlTemplateDao(); + assertEquals("mcp_guardrail_yaml_templates", dao.getCollName()); + } + + @Test + public void testGetClassT() { + MCPGuardrailYamlTemplateDao dao = new MCPGuardrailYamlTemplateDao(); + assertEquals(YamlTemplate.class, dao.getClassT()); + } +} diff --git a/libs/dao/src/test/java/com/akto/dto/mcp/MCPGuardrailConfigYamlParserTest.java b/libs/dao/src/test/java/com/akto/dto/mcp/MCPGuardrailConfigYamlParserTest.java new file mode 100644 index 0000000000..6d057a2745 --- /dev/null +++ b/libs/dao/src/test/java/com/akto/dto/mcp/MCPGuardrailConfigYamlParserTest.java @@ -0,0 +1,90 @@ +package com.akto.dto.mcp; + +import org.junit.Test; +import static org.junit.Assert.*; + +import java.util.HashMap; +import java.util.Map; + +public class MCPGuardrailConfigYamlParserTest { + + @Test + public void testParseValidYamlTemplate() throws Exception { + String yamlContent = "id: \"test-guardrail-001\"\n" + + "filter:\n" + + " pred:\n" + + " - data:\n" + + " - \"test-value\""; + + MCPGuardrailConfig config = MCPGuardrailConfigYamlParser.parseTemplate(yamlContent); + + assertNotNull(config); + assertEquals("test-guardrail-001", config.getId()); + assertNotNull(config.getFilter()); + } + + @Test + public void testParseYamlWithDefaults() throws Exception { + String yamlContent = "id: \"minimal-guardrail\"\n" + + "filter:\n" + + " pred:\n" + + " - data:\n" + + " - \"default-value\""; + + MCPGuardrailConfig config = MCPGuardrailConfigYamlParser.parseTemplate(yamlContent); + + assertNotNull(config); + assertEquals("minimal-guardrail", config.getId()); + assertNotNull(config.getFilter()); + } + + + @Test(expected = Exception.class) + public void testParseInvalidYaml() throws Exception { + String invalidYaml = "invalid: yaml: content: ["; + MCPGuardrailConfigYamlParser.parseTemplate(invalidYaml); + } + + @Test(expected = IllegalArgumentException.class) + public void testParseEmptyYaml() throws Exception { + 
MCPGuardrailConfigYamlParser.parseTemplate(""); + } + + @Test(expected = IllegalArgumentException.class) + public void testParseNullYaml() throws Exception { + MCPGuardrailConfigYamlParser.parseTemplate(null); + } + + @Test + public void testParseConfigWithMap() throws Exception { + Map configMap = new HashMap<>(); + configMap.put("id", "map-test-001"); + + Map filterMap = new HashMap<>(); + java.util.List predList = new java.util.ArrayList<>(); + Map predItem = new HashMap<>(); + java.util.List dataList = new java.util.ArrayList<>(); + dataList.add("test-data"); + predItem.put("data", dataList); + predList.add(predItem); + filterMap.put("pred", predList); + configMap.put("filter", filterMap); + + + MCPGuardrailConfig config = MCPGuardrailConfigYamlParser.parseConfig(configMap); + + assertNotNull(config); + assertEquals("map-test-001", config.getId()); + assertNotNull(config.getFilter()); + } + + @Test + public void testParseConfigWithNullId() throws Exception { + Map configMap = new HashMap<>(); + configMap.put("name", "No ID Test"); + + MCPGuardrailConfig config = MCPGuardrailConfigYamlParser.parseConfig(configMap); + + assertNull(config); + } +} diff --git a/libs/mcp-guardrails/.gitignore b/libs/mcp-guardrails/.gitignore new file mode 100644 index 0000000000..de67df5e37 --- /dev/null +++ b/libs/mcp-guardrails/.gitignore @@ -0,0 +1,18 @@ +# Binaries for programs and plugins +*.exe +*.exe~ +*.dll +*.so +*.dylib + +# Test binary, built with `go test -c` +*.test + +# Output of the go coverage tool, specifically when used with LiteIDE +*.out + +# Dependency directories (remove the comment below to include it) +# vendor/ + +# Go workspace file +go.work diff --git a/libs/mcp-guardrails/INTEGRATION_SUMMARY.md b/libs/mcp-guardrails/INTEGRATION_SUMMARY.md new file mode 100644 index 0000000000..a4ce27fd54 --- /dev/null +++ b/libs/mcp-guardrails/INTEGRATION_SUMMARY.md @@ -0,0 +1,199 @@ +# MCP Guardrails Library - Integration Summary + +## Overview + +This document 
summarizes the complete integration of the MCP Guardrails library with your MCP proxy server. The library provides comprehensive security and compliance features for MCP (Model Context Protocol) servers. + +## What Was Implemented + +### 1. Enhanced Guardrails Library (`guardrails.go`) +- **Template Fetcher Service**: Background service that periodically fetches guardrail templates from your database-abstractor service +- **MCP Request/Response Processing**: Convenience functions for processing MCP requests and responses +- **Concurrent Safety**: Thread-safe operations with mutex protection +- **Error Handling**: Robust error handling with graceful degradation + +### 2. MCP Proxy Integration Layer (`proxy_integration.go`) +- **Environment Variable Configuration**: Reads all configuration from environment variables +- **Authentication Support**: Optional token-based authentication for the database-abstractor service +- **Simple API**: Clean, easy-to-use functions for MCP proxy integration +- **Status Monitoring**: Built-in status and health check capabilities + +### 3. Enhanced Template Client (`template_client.go`) +- **Authentication**: Bearer token authentication support +- **Error Handling**: Comprehensive error handling for API requests +- **Health Checks**: Service health monitoring capabilities + +### 4. 
Complete Example (`examples/mcp_proxy_example.go`) +- **Full MCP Proxy Implementation**: Complete working example of MCP proxy with guardrails +- **HTTP Endpoints**: Health checks, status monitoring, and MCP request handling +- **Error Handling**: Proper error handling and response formatting + +## Key Features + +### 🔄 Automatic Template Fetching +- Fetches guardrail templates from your database-abstractor service at configurable intervals +- Handles authentication with optional bearer tokens +- Continues operation with previously loaded templates if fetching fails + +### 🛡️ Request Guardrails +- Validates incoming MCP requests +- Applies rate limiting to prevent abuse +- Filters malicious or inappropriate content +- Validates input parameters + +### 🔒 Response Guardrails +- Sanitizes sensitive data in responses +- Applies content filtering +- Masks or redacts sensitive information +- Monitors output for policy violations + +### ⚙️ Configurable Security +- Enable/disable individual security features +- Customize sensitive field patterns +- Configure rate limiting parameters +- Set validation rules + +### 📊 Monitoring & Observability +- Health check endpoints +- Status monitoring +- Detailed logging +- Template refresh status + +## Environment Variables + +### Required +```bash +export GUARDRAIL_SERVICE_URL="http://localhost:8080" +``` + +### Optional +```bash +export GUARDRAIL_SERVICE_TOKEN="your-auth-token" +export GUARDRAIL_REFRESH_INTERVAL="10" # minutes +export GUARDRAIL_ENABLE_SANITIZATION="true" +export GUARDRAIL_ENABLE_CONTENT_FILTERING="true" +export GUARDRAIL_ENABLE_RATE_LIMITING="true" +export GUARDRAIL_ENABLE_INPUT_VALIDATION="true" +export GUARDRAIL_ENABLE_OUTPUT_FILTERING="true" +export GUARDRAIL_ENABLE_LOGGING="true" +``` + +## Integration Steps + +### 1. Import the Library +```go +import "github.com/akto-api-security/akto/libs/mcp-guardrails" +``` + +### 2. 
Initialize Integration +```go +guardrails, err := guardrails.NewMCPProxyIntegration() +if err != nil { + log.Fatalf("Failed to initialize guardrails: %v", err) +} +``` + +### 3. Apply Request Guardrails +```go +blocked, reason, err := guardrails.RequestGuardrail(requestData) +if blocked { + // Handle blocked request + return +} +``` + +### 4. Apply Response Guardrails +```go +processedResponse, modified, err := guardrails.ResponseGuardrail(responseData) +// Use processedResponse +``` + +## API Endpoints + +Your MCP proxy will expose these endpoints: + +- **`/mcp`**: Main MCP endpoint with guardrails applied +- **`/health`**: Health check with guardrails status +- **`/status`**: Detailed status information and manual template refresh + +## Database-Abstractor Integration + +The library integrates with your existing database-abstractor service using these endpoints: + +- **`/api/mcp/fetchGuardrailTemplates`**: Fetch all active templates +- **`/api/mcp/fetchGuardrailTemplatesByType`**: Fetch templates by type +- **`/api/mcp/fetchGuardrailConfigs`**: Fetch parsed configurations +- **`/api/mcp/fetchGuardrailTemplate`**: Fetch specific template +- **`/api/mcp/health`**: Health check endpoint + +## Security Benefits + +### 1. **Data Protection** +- Automatic detection and redaction of sensitive information +- Configurable sensitive field patterns +- Support for custom redaction rules + +### 2. **Content Security** +- Keyword-based content filtering +- Regex pattern matching +- Configurable blocking/warning actions + +### 3. **Rate Limiting** +- Prevents API abuse +- Configurable request limits +- Burst protection + +### 4. **Input Validation** +- Validates MCP request structure +- Custom validation rules +- Parameter validation + +### 5. 
**Audit & Compliance** +- Comprehensive logging +- Request/response monitoring +- Template application tracking + +## Performance Considerations + +- **Background Template Fetching**: Templates are fetched in the background without blocking requests +- **Caching**: Templates are cached locally and refreshed periodically +- **Graceful Degradation**: System continues operating with previously loaded templates if fetching fails +- **Concurrent Safety**: All operations are thread-safe for high-throughput scenarios + +## Monitoring & Troubleshooting + +### Health Monitoring +```bash +curl http://localhost:8080/health +``` + +### Status Check +```bash +curl http://localhost:8080/status +``` + +### Manual Template Refresh +```bash +curl -X POST http://localhost:8080/status +``` + +### Log Monitoring +- Watch application logs for guardrail warnings and errors +- Monitor template fetch success/failure rates +- Track blocked requests and modified responses + +## Next Steps + +1. **Deploy the Integration**: Use the example code as a starting point for your MCP proxy +2. **Configure Templates**: Set up guardrail templates in your database-abstractor service +3. **Test Security**: Verify that guardrails are working correctly with test requests +4. **Monitor Performance**: Monitor the system for any performance impact +5. **Tune Configuration**: Adjust environment variables based on your requirements + +## Support + +- **Documentation**: See `MCP_PROXY_INTEGRATION.md` for detailed integration guide +- **Example Code**: See `examples/mcp_proxy_example.go` for complete implementation +- **API Reference**: All functions are documented with Go doc comments + +The integration is now complete and ready for production use! 
diff --git a/libs/mcp-guardrails/MCP_PROXY_INTEGRATION.md b/libs/mcp-guardrails/MCP_PROXY_INTEGRATION.md new file mode 100644 index 0000000000..79e5be6e21 --- /dev/null +++ b/libs/mcp-guardrails/MCP_PROXY_INTEGRATION.md @@ -0,0 +1,268 @@ +# MCP Proxy Integration with Guardrails + +This document explains how to integrate the MCP Guardrails library with your MCP proxy server. + +## Overview + +The MCP Guardrails library provides security and compliance features for MCP (Model Context Protocol) servers by: + +1. **Fetching Templates**: Periodically retrieving guardrail templates from your database-abstractor service +2. **Request Guardrails**: Validating and filtering incoming MCP requests +3. **Response Guardrails**: Sanitizing and filtering outgoing MCP responses +4. **Rate Limiting**: Controlling request rates to prevent abuse +5. **Data Sanitization**: Removing or masking sensitive information + +## Environment Variables + +Configure the guardrails system using these environment variables: + +### Required +- `GUARDRAIL_SERVICE_URL`: URL of your database-abstractor service (e.g., `http://localhost:8080`) + +### Optional +- `GUARDRAIL_SERVICE_TOKEN`: Authentication token for the service API +- `GUARDRAIL_REFRESH_INTERVAL`: Template refresh interval in minutes (default: 10) +- `GUARDRAIL_ENABLE_SANITIZATION`: Enable data sanitization (default: true) +- `GUARDRAIL_ENABLE_CONTENT_FILTERING`: Enable content filtering (default: true) +- `GUARDRAIL_ENABLE_RATE_LIMITING`: Enable rate limiting (default: true) +- `GUARDRAIL_ENABLE_INPUT_VALIDATION`: Enable input validation (default: true) +- `GUARDRAIL_ENABLE_OUTPUT_FILTERING`: Enable output filtering (default: true) +- `GUARDRAIL_ENABLE_LOGGING`: Enable logging (default: true) + +## Integration Example + +```go +package main + +import ( + "log" + "net/http" + + "github.com/akto-api-security/akto/libs/mcp-guardrails" +) + +func main() { + // Initialize guardrails integration + guardrails, err := 
guardrails.NewMCPProxyIntegration() + if err != nil { + log.Fatalf("Failed to initialize guardrails: %v", err) + } + + // Create HTTP server with guardrails middleware + server := &http.Server{ + Addr: ":8080", + Handler: createHandler(guardrails), + } + + log.Println("MCP Proxy with Guardrails starting on :8080") + log.Fatal(server.ListenAndServe()) +} + +func createHandler(guardrails *guardrails.MCPProxyIntegration) http.Handler { + mux := http.NewServeMux() + + mux.HandleFunc("/mcp", func(w http.ResponseWriter, r *http.Request) { + handleMCPRequest(w, r, guardrails) + }) + + return mux +} + +func handleMCPRequest(w http.ResponseWriter, r *http.Request, guardrails *guardrails.MCPProxyIntegration) { + // Read request body + requestBody, err := io.ReadAll(r.Body) + if err != nil { + http.Error(w, "Failed to read request body", http.StatusBadRequest) + return + } + + // Apply request guardrails + blocked, reason, err := guardrails.RequestGuardrail(requestBody) + if err != nil { + log.Printf("Guardrail processing error: %v", err) + http.Error(w, "Internal server error", http.StatusInternalServerError) + return + } + + if blocked { + // Send error response + errorResponse := map[string]interface{}{ + "jsonrpc": "2.0", + "error": map[string]interface{}{ + "code": -32603, + "message": reason, + }, + } + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusForbidden) + json.NewEncoder(w).Encode(errorResponse) + return + } + + // Process the request (forward to your MCP server) + responseData, err := forwardToMCPServer(requestBody) + if err != nil { + // Handle error + return + } + + // Apply response guardrails + processedResponse, modified, err := guardrails.ResponseGuardrail(responseData) + if err != nil { + log.Printf("Response guardrail processing error: %v", err) + processedResponse = responseData // Use original response if processing fails + } + + if modified { + log.Println("Response was modified by guardrails") + } + + // Send response + 
w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + w.Write(processedResponse) +} +``` + +## API Reference + +### NewMCPProxyIntegration() + +Creates a new MCP proxy integration instance. Reads configuration from environment variables and starts the background template fetcher. + +```go +guardrails, err := guardrails.NewMCPProxyIntegration() +``` + +### RequestGuardrail(requestData []byte) (bool, string, error) + +Processes an incoming MCP request and determines if it should be blocked. + +- **Parameters**: `requestData` - Raw JSON bytes of the MCP request +- **Returns**: + - `bool` - True if request should be blocked + - `string` - Reason for blocking (if blocked) + - `error` - Processing error (if any) + +```go +blocked, reason, err := guardrails.RequestGuardrail(requestData) +``` + +### ResponseGuardrail(responseData []byte) ([]byte, bool, error) + +Processes an outgoing MCP response and applies sanitization/filtering. + +- **Parameters**: `responseData` - Raw JSON bytes of the MCP response +- **Returns**: + - `[]byte` - Processed response data (may be modified) + - `bool` - True if response was modified + - `error` - Processing error (if any) + +```go +processedResponse, modified, err := guardrails.ResponseGuardrail(responseData) +``` + +### GetStatus() map[string]interface{} + +Returns the current status of the guardrails system. + +```go +status := guardrails.GetStatus() +// Returns: {"templates_loaded": 5, "configs_loaded": 3, "engine_active": true} +``` + +### RefreshTemplates() error + +Manually triggers a template refresh from the database-abstractor service. + +```go +err := guardrails.RefreshTemplates() +``` + +## Endpoints + +The integration provides these HTTP endpoints: + +### `/mcp` +Main MCP endpoint that applies guardrails to requests and responses. + +### `/health` +Health check endpoint that returns the status of the guardrails system. 
+ +```json +{ + "status": "healthy", + "guardrails": { + "templates_loaded": 5, + "configs_loaded": 3, + "engine_active": true + } +} +``` + +### `/status` +Status endpoint that returns detailed guardrails information. + +**GET**: Returns current status +**POST**: Triggers manual template refresh + +## Template Fetching + +The system automatically fetches guardrail templates from your database-abstractor service at regular intervals. The templates are used to: + +1. **Define Sensitive Data Patterns**: Regex patterns for detecting and redacting sensitive information +2. **Set Content Filters**: Rules for blocking or warning about specific content +3. **Configure Rate Limits**: Request rate limiting parameters +4. **Define Validation Rules**: Input validation requirements + +## Error Handling + +The guardrails system is designed to be resilient: + +- **Template Fetch Failures**: If template fetching fails, the system continues with previously loaded templates +- **Processing Errors**: If guardrail processing fails, requests/responses are allowed through with warnings logged +- **Invalid JSON**: Malformed requests/responses are blocked with appropriate error messages + +## Security Considerations + +1. **Authentication**: Use `GUARDRAIL_SERVICE_TOKEN` to authenticate with your database-abstractor service +2. **Network Security**: Ensure the connection between the proxy and database-abstractor is secure +3. **Template Validation**: The system validates templates before applying them +4. **Logging**: Enable logging to monitor guardrail activity and potential security issues + +## Monitoring + +Monitor the guardrails system through: + +1. **Health Checks**: Use the `/health` endpoint for health monitoring +2. **Status Monitoring**: Use the `/status` endpoint for detailed status information +3. **Logs**: Monitor application logs for guardrail warnings and errors +4. **Metrics**: Track blocked requests and modified responses + +## Troubleshooting + +### Common Issues + +1. 
**Template Fetch Failures** + - Check `GUARDRAIL_SERVICE_URL` is correct + - Verify network connectivity + - Check authentication token if required + +2. **High Request Blocking** + - Review template configurations + - Check if rate limits are too restrictive + - Verify content filter rules + +3. **Performance Issues** + - Adjust `GUARDRAIL_REFRESH_INTERVAL` to reduce API calls + - Monitor template complexity + - Check if too many patterns are being applied + +### Debug Mode + +Enable debug logging by setting: +```bash +export GUARDRAIL_ENABLE_LOGGING=true +``` + +This will provide detailed logs of guardrail operations. diff --git a/libs/mcp-guardrails/README.md b/libs/mcp-guardrails/README.md new file mode 100644 index 0000000000..d1e93a630a --- /dev/null +++ b/libs/mcp-guardrails/README.md @@ -0,0 +1,303 @@ +# MCP Guardrails Library + +A Golang library for implementing guardrails in MCP (Model Context Protocol) applications. This library fetches guardrail templates from an API and provides functions to modify requests and responses based on regex patterns defined in the templates. 
+ +## Features + +- **Automatic Template Fetching**: Fetches guardrail templates from the API at regular intervals (default: 10 minutes) +- **Request/Response Modification**: Modifies requests and responses based on regex patterns in templates +- **YAML Template Parsing**: Parses YAML templates containing request_payload and response_payload filters +- **JSON Support**: Handles both string and JSON data modification +- **Health Monitoring**: Provides health status and template statistics +- **Sensitive Data Detection**: Detects and sanitizes sensitive data based on patterns +- **Thread-Safe**: All operations are thread-safe with proper locking + +## Installation + +```bash +go get github.com/akto-api-security/akto/libs/mcp-guardrails +``` + +## Quick Start + +```go +package main + +import ( + "fmt" + "time" + "github.com/akto-api-security/akto/libs/mcp-guardrails" +) + +func main() { + // Create configuration + config := guardrails.ClientConfig{ + APIURL: "http://localhost:8082", + AuthToken: "your-auth-token", + FetchInterval: 10 * time.Minute, + } + + // Create and start the guardrail engine + engine := guardrails.NewGuardrailEngine(config) + engine.Start() + + // Modify a request + requestData := `{"method": "test", "data": "1234-5678-9012-3456"}` + result := engine.ModifyRequest(requestData) + + if result.Blocked { + fmt.Printf("Request blocked: %s\n", result.Reason) + } +} +``` + +## API Reference + +### ClientConfig + +Configuration for the guardrail client. + +```go +type ClientConfig struct { + APIURL string // API base URL + AuthToken string // Authentication token + FetchInterval time.Duration // Template fetch interval +} +``` + +### GuardrailEngine + +Main engine for managing guardrails. 
+ +#### Methods + +- `NewGuardrailEngine(config ClientConfig) *GuardrailEngine` - Creates a new engine +- `Start()` - Starts periodic template fetching +- `Stop()` - Stops the engine +- `TriggerTemplateFetching() error` - Manually triggers template fetching +- `ModifyRequest(requestData string) ModificationResult` - Modifies a request +- `ModifyResponse(responseData string) ModificationResult` - Modifies a response +- `ModifyRequestJSON(requestData []byte) (ModificationResult, error)` - Modifies JSON request +- `ModifyResponseJSON(responseData []byte) (ModificationResult, error)` - Modifies JSON response +- `GetTemplates() map[string]ParsedTemplate` - Returns all loaded templates +- `GetTemplate(id string) (ParsedTemplate, bool)` - Returns a specific template +- `GetTemplateStats() map[string]interface{}` - Returns template statistics +- `SanitizeData(data string, patterns []string) string` - Sanitizes sensitive data +- `CheckForSensitiveData(data string, patterns []string) (bool, []string)` - Checks for sensitive data +- `IsHealthy() bool` - Checks if the engine is healthy +- `GetHealthStatus() map[string]interface{}` - Returns detailed health status + +### ModificationResult + +Result of request/response modification. 
+ +```go +type ModificationResult struct { + Modified bool `json:"modified"` // Whether data was modified + Blocked bool `json:"blocked"` // Whether request/response was blocked + Reason string `json:"reason"` // Block reason (if blocked) + Warnings []string `json:"warnings"` // Warning messages + Data string `json:"data"` // Modified data +} +``` + +## Template Format + +Templates are YAML files with the following structure: + +```yaml +id: PIIDataLeak +filter: + or: + - request_payload: + regex: + - "\\b\\d{4}[- ]?\\d{4}[- ]?\\d{4}\\b" + - response_payload: + regex: + - "\\b\\d{4}[- ]?\\d{4}[- ]?\\d{4}\\b" +info: + name: "PIIDataLeak" + description: "PII Data Leak detection" + severity: MEDIUM + category: + name: "PIIDataLeak" + displayName: "PII Data Leak" +``` + +## Usage Examples + +### Basic Usage + +```go +// Create engine +config := guardrails.ClientConfig{ + APIURL: "http://localhost:8082", + AuthToken: "testing", + FetchInterval: 10 * time.Minute, +} +engine := guardrails.NewGuardrailEngine(config) +engine.Start() + +// Modify request +requestData := `{"method": "test", "data": "1234-5678-9012-3456"}` +result := engine.ModifyRequest(requestData) + +if result.Blocked { + log.Printf("Request blocked: %s", result.Reason) +} else if result.Modified { + log.Printf("Request modified: %s", result.Data) +} +``` + +### JSON Processing + +```go +// Process JSON request +jsonRequest := []byte(`{"method": "test", "data": "1234-5678-9012-3456"}`) +result, err := engine.ModifyRequestJSON(jsonRequest) +if err != nil { + log.Printf("Error: %v", err) + return +} + +if result.Blocked { + log.Printf("JSON request blocked: %s", result.Reason) +} +``` + +### Health Monitoring + +```go +// Check health +if !engine.IsHealthy() { + log.Printf("Guardrail engine is not healthy") +} + +// Get detailed health status +status := engine.GetHealthStatus() +fmt.Printf("Health: %v\n", status["healthy"]) +fmt.Printf("Templates: %v\n", status["template_count"]) +``` + +### Sensitive Data 
Handling + +```go +// Check for sensitive data +data := "My credit card is 1234-5678-9012-3456" +patterns := []string{`\b\d{4}[- ]?\d{4}[- ]?\d{4}\b`} + +hasSensitive, matchedPatterns := engine.CheckForSensitiveData(data, patterns) +if hasSensitive { + log.Printf("Sensitive data detected: %v", matchedPatterns) +} + +// Sanitize data +sanitized := engine.SanitizeData(data, patterns) +log.Printf("Sanitized: %s", sanitized) +``` + +## Testing + +Run the tests: + +```bash +go test ./... +``` + +Run tests with coverage: + +```bash +go test -cover ./... +``` + +Run benchmarks: + +```bash +go test -bench=. ./... +``` + +## Integration with MCP Proxy + +This library is designed to be imported by the MCP proxy application. The proxy can use it to: + +1. Fetch guardrail templates from the API +2. Apply guardrails to incoming requests +3. Apply guardrails to outgoing responses +4. Monitor the health of the guardrail system + +Example integration: + +```go +// In your MCP proxy +import "github.com/akto-api-security/akto/libs/mcp-guardrails" + +// Initialize guardrails +config := guardrails.ClientConfig{ + APIURL: "http://localhost:8082", + AuthToken: "your-token", + FetchInterval: 10 * time.Minute, +} +guardrailEngine := guardrails.NewGuardrailEngine(config) +guardrailEngine.Start() + +// In your request handler +func handleRequest(requestData []byte) ([]byte, error) { + result, err := guardrailEngine.ModifyRequestJSON(requestData) + if err != nil { + return nil, err + } + + if result.Blocked { + return nil, fmt.Errorf("request blocked: %s", result.Reason) + } + + // Process the request... 
+ responseData := processRequest(result.Data) + + // Apply response guardrails + responseResult, err := guardrailEngine.ModifyResponseJSON(responseData) + if err != nil { + return nil, err + } + + if responseResult.Blocked { + return nil, fmt.Errorf("response blocked: %s", responseResult.Reason) + } + + return []byte(responseResult.Data), nil +} +``` + +## Configuration + +The library supports the following configuration options: + +- `APIURL`: Base URL of the API server +- `AuthToken`: Bearer token for API authentication +- `FetchInterval`: How often to fetch templates (default: 10 minutes) + +## Error Handling + +The library provides comprehensive error handling: + +- Network errors during template fetching +- Invalid JSON/YAML parsing +- Regex compilation errors +- Template validation errors + +All errors are logged and the library continues to operate with previously loaded templates. + +## Performance + +The library is optimized for performance: + +- Thread-safe operations with minimal locking +- Efficient regex compilation and caching +- Minimal memory allocation +- Fast template matching + +Benchmark results are available in the test suite. + +## License + +This library is part of the Akto project and follows the same licensing terms. \ No newline at end of file diff --git a/libs/mcp-guardrails/TEMPLATE_INTEGRATION.md b/libs/mcp-guardrails/TEMPLATE_INTEGRATION.md new file mode 100644 index 0000000000..1bb1b9998b --- /dev/null +++ b/libs/mcp-guardrails/TEMPLATE_INTEGRATION.md @@ -0,0 +1,238 @@ +# MCP Guardrail Template Integration + +This document explains how MCP Guardrail policies are fetched from the database using YAML templates. + +## Overview + +The MCP Guardrail system now supports fetching policy templates from a database via the database-abstractor service. This allows for dynamic policy management and centralized configuration. 
+ +## Architecture + +``` +┌─────────────────┐    ┌─────────────────┐    ┌─────────────────┐ +│   Go Library    │───▶│   Database      │───▶│    MongoDB      │ +│ (mcp-guardrails)│    │   Abstractor    │    │                 │ +└─────────────────┘    └─────────────────┘    └─────────────────┘ +``` + +## Components + +### 1. Database Layer (Java) + +#### DTOs +- `YamlTemplate` - Represents a YAML template stored in the database (reuses existing class) +- `Info` - Metadata about the template (reuses existing class from test_editor) +- `Category` - Category information for organizing templates (reuses existing class) +- `MCPGuardrailType` - Enum defining different types of guardrails +- `MCPGuardrailConfig` - Parsed configuration from YAML templates + +#### DAO +- `MCPGuardrailYamlTemplateDao` - Data access object for template operations + - `fetchMCPGuardrailConfig()` - Fetch and parse all templates + - `fetchActiveTemplates()` - Fetch only active templates + - `fetchTemplatesByType()` - Fetch templates by guardrail type + +#### Action +- `MCPGuardrailsAction` - REST API endpoints for template operations + - `fetchMCPGuardrailTemplates` - Get all templates + - `fetchMCPGuardrailTemplatesByType` - Get templates by type + - `fetchMCPGuardrailConfigs` - Get parsed configurations + - `fetchMCPGuardrailTemplate` - Get specific template by ID + - `health` - Health check endpoint + +### 2.
API Endpoints + +All endpoints are configured in `struts.xml` under the `/api/mcp/` namespace: + +- `GET /api/mcp/health` - Health check +- `POST /api/mcp/fetchGuardrailTemplates` - Fetch all templates +- `POST /api/mcp/fetchGuardrailTemplatesByType` - Fetch templates by type +- `POST /api/mcp/fetchGuardrailConfigs` - Fetch parsed configurations +- `POST /api/mcp/fetchGuardrailTemplate` - Fetch specific template + +### 3. Go Library Integration + +#### New Types +- `YamlTemplate` - Go representation of database template (matches Java YamlTemplate) +- `Info` - Template metadata (matches Java Info) +- `Category` - Template category (matches Java Category) +- `MCPGuardrailConfig` - Go representation of parsed configuration +- `TemplateClient` - HTTP client for API communication +- `APIResponse` - Response structure from API calls + +#### Template Client +```go +// Create a template client +client := guardrails.NewTemplateClient("http://localhost:8080") + +// Fetch templates +templates, err := client.FetchGuardrailTemplates(true) // activeOnly=true +if err != nil { + log.Fatal(err) +} + +// Health check +err = client.HealthCheck() +if err != nil { + log.Printf("API not available: %v", err) +} +``` + +#### Enhanced Guardrail Engine +```go +// Create engine with template client +engine := guardrails.NewGuardrailEngineWithClient(config, templateClient) + +// Load templates from API +err := engine.LoadTemplatesFromAPI() +if err != nil { + log.Printf("Failed to load templates: %v", err) +} + +// Access loaded templates +templates := engine.GetAllTemplates() +configs := engine.GetAllConfigs() +``` + +## YAML Template Format + +Templates are stored as YAML content in the database. 
Here's the structure: + +```yaml +id: "data_sanitization_basic" +name: "Basic Data Sanitization" +description: "Sanitizes common sensitive data patterns" +version: "1.0.0" +type: "DATA_SANITIZATION" +enabled: true +priority: 100 + +configuration: + sensitiveFields: + - "password" + - "api_key" + - "secret" + + patterns: + - name: "Credit Card" + pattern: "\\b(?:\\d[ -]*?){13,16}\\b" + replacement: "***CREDIT_CARD***" + + validationRules: + method: "required" + + outputFilters: + - "block_sensitive" + +info: + name: "Basic Data Sanitization" + description: "Removes sensitive data patterns" + category: + name: "data_protection" + displayName: "Data Protection" + severity: "HIGH" + tags: + - "pii" + - "data_protection" +``` + +## Guardrail Types + +The system supports the following guardrail types: + +- `DATA_SANITIZATION` - Remove or redact sensitive data +- `CONTENT_FILTERING` - Filter content based on rules +- `INPUT_VALIDATION` - Validate input parameters +- `OUTPUT_FILTERING` - Filter output content +- `RATE_LIMITING` - Limit request rates +- `CUSTOM` - Custom guardrail implementations + +## Usage Examples + +### 1. Basic Usage + +```go +// Create template client +client := guardrails.NewTemplateClient("http://database-abstractor:8080") + +// Create engine with client +engine := guardrails.NewGuardrailEngineWithClient(config, client) + +// Load templates +if err := engine.LoadTemplatesFromAPI(); err != nil { + log.Printf("Using default templates: %v", err) +} + +// Process requests/responses +result := engine.ProcessResponse(response) +``` + +### 2. Type-Specific Loading + +```go +// Load only data sanitization templates +err := engine.LoadTemplatesByType("DATA_SANITIZATION") +if err != nil { + log.Printf("Failed to load data sanitization templates: %v", err) +} +``` + +### 3. 
Template Refresh + +```go +// Periodically refresh templates +ticker := time.NewTicker(5 * time.Minute) +go func() { + for range ticker.C { + if err := engine.RefreshTemplates(); err != nil { + log.Printf("Failed to refresh templates: %v", err) + } + } +}() +``` + +## Database Schema + +Templates are stored in the `mcp_guardrail_yaml_templates` collection using the standard `YamlTemplate` schema: + +- `id` (String) - Unique template identifier +- `createdAt` (int) - Creation timestamp +- `author` (String) - Template author +- `source` (String) - Template source (e.g., "AKTO_TEMPLATES") +- `updatedAt` (int) - Last update timestamp +- `hash` (int) - Content hash for change detection +- `content` (String) - YAML template content +- `info` (Info) - Template metadata +- `inactive` (boolean) - Whether the template is inactive (`true` means it is disabled and skipped when loading) +- `repositoryUrl` (String) - Source repository URL + +## Error Handling + +The system gracefully handles various error scenarios: + +1. **API Unavailable**: Falls back to default patterns +2. **Invalid Templates**: Logs errors and continues with valid templates +3. **Network Issues**: Retries with exponential backoff +4.
**Access Control**: Restrict template modification to authorized users + +## Performance + +- Templates are cached in memory for fast access +- Periodic refresh minimizes database load +- Lazy loading of templates by type when needed +- Connection pooling for database access + +## Monitoring + +Health check endpoint provides: +- Database connectivity status +- Template count +- Last successful refresh timestamp +- Error counts and rates diff --git a/libs/mcp-guardrails/USAGE_EXAMPLE.md b/libs/mcp-guardrails/USAGE_EXAMPLE.md new file mode 100644 index 0000000000..cd2bc3bdc3 --- /dev/null +++ b/libs/mcp-guardrails/USAGE_EXAMPLE.md @@ -0,0 +1,122 @@ +# Using MCP Guardrails as a Library + +## How to Import and Use + +Once the `mcp-guardrails` library is published to GitHub, you can use it in any Go project: + +### 1. Import the Library + +```go +import "github.com/akto-api-security/akto/libs/mcp-guardrails" +``` + +### 2. Basic Usage Example + +```go +package main + +import ( + "encoding/json" + "fmt" + "time" + + "github.com/akto-api-security/akto/libs/mcp-guardrails" +) + +func main() { + // Create configuration + config := &guardrails.GuardrailConfig{ + EnableDataSanitization: true, + SensitiveFields: []string{"password", "api_key", "secret"}, + EnableContentFiltering: true, + BlockedKeywords: []string{"malicious", "exploit"}, + EnableRateLimiting: true, + RateLimitConfig: guardrails.RateLimitConfig{ + RequestsPerMinute: 100, + BurstSize: 10, + WindowSize: time.Minute, + }, + } + + // Create engine + engine := guardrails.NewGuardrailEngine(config) + + // Process a request + request := &guardrails.MCPRequest{ + ID: "req_1", + Method: "tools/call", + Params: json.RawMessage(`{"tool": "file_read", "path": "/etc/passwd"}`), + } + + result := engine.ProcessRequest(request) + + if result.Blocked { + fmt.Printf("Request blocked: %s\n", result.BlockReason) + } else { + fmt.Println("Request allowed") + } + + // Process a response + response := &guardrails.MCPResponse{ + ID: 
"resp_1", + Result: json.RawMessage(`{ + "content": "user data with password: secret123", + "api_key": "sk-1234567890" + }`), + } + + result = engine.ProcessResponse(response) + + if result.SanitizedResponse != nil { + fmt.Println("Response sanitized successfully") + } +} +``` + +### 3. Integration with MCP Proxy + +```go +// In your MCP proxy server +func (s *MCPProxyServer) handleRequest(req *MCPRequest) *MCPResponse { + // Apply guardrails before processing + guardrailResult := s.guardrailEngine.ProcessRequest(req) + + if guardrailResult.Blocked { + return &MCPResponse{ + ID: req.ID, + Error: &MCPError{ + Code: -32000, + Message: "Request blocked by guardrails", + Data: guardrailResult.BlockReason, + }, + } + } + + // Process the request normally + response := s.processRequest(req) + + // Apply guardrails to response + responseResult := s.guardrailEngine.ProcessResponse(response) + + if responseResult.SanitizedResponse != nil { + return responseResult.SanitizedResponse + } + + return response +} +``` + +## Publishing Steps + +To make this library available for import: + +1. **Push to GitHub**: Ensure the code is in the `akto-api-security/akto` repository +2. **Tag a Release**: Create a git tag for versioning +3. **Go Proxy**: The Go module proxy will automatically index it +4. 
**Import**: Other projects can then import it using the full module path + +## Current Status + +✅ **Ready for local development** - Works within the Go workspace +🔄 **Ready for publishing** - Module name updated to match GitHub repository +⏳ **Waiting for publish** - Need to push to GitHub and create release tags diff --git a/libs/mcp-guardrails/client.go b/libs/mcp-guardrails/client.go new file mode 100644 index 0000000000..db315b152f --- /dev/null +++ b/libs/mcp-guardrails/client.go @@ -0,0 +1,149 @@ +package guardrails + +import ( + "bytes" + "encoding/json" + "fmt" + "io" + "net/http" + "time" + + "gopkg.in/yaml.v3" +) + +// NewGuardrailClient creates a new guardrail client +func NewGuardrailClient(config ClientConfig) *GuardrailClient { + if config.FetchInterval == 0 { + config.FetchInterval = 10 * time.Minute + } + + return &GuardrailClient{ + APIURL: config.APIURL, + AuthToken: config.AuthToken, + Templates: make(map[string]ParsedTemplate), + FetchInterval: config.FetchInterval, + } +} + +// FetchGuardrailTemplates fetches guardrail templates from the API +func (c *GuardrailClient) FetchGuardrailTemplates(activeOnly bool) error { + url := fmt.Sprintf("%s/api/mcp/fetchGuardrailTemplates", c.APIURL) + + requestBody := map[string]bool{ + "activeOnly": activeOnly, + } + + jsonData, err := json.Marshal(requestBody) + if err != nil { + return fmt.Errorf("failed to marshal request: %w", err) + } + + req, err := http.NewRequest("POST", url, bytes.NewBuffer(jsonData)) + if err != nil { + return fmt.Errorf("failed to create request: %w", err) + } + + req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", c.AuthToken)) + req.Header.Set("Content-Type", "application/json") + + client := &http.Client{Timeout: 30 * time.Second} + resp, err := client.Do(req) + if err != nil { + return fmt.Errorf("failed to make request: %w", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + return fmt.Errorf("API request failed with status: %d",
resp.StatusCode) + } + + body, err := io.ReadAll(resp.Body) + if err != nil { + return fmt.Errorf("failed to read response: %w", err) + } + + var apiResponse APIResponse + if err := json.Unmarshal(body, &apiResponse); err != nil { + return fmt.Errorf("failed to unmarshal response: %w", err) + } + + // Parse templates + newTemplates := make(map[string]ParsedTemplate) + for _, template := range apiResponse.MCPGuardrailTemplates { + if template.Inactive { + continue + } + + parsed, err := c.parseTemplate(template) + if err != nil { + fmt.Printf("Warning: failed to parse template %s: %v\n", template.ID, err) + continue + } + + newTemplates[template.ID] = parsed + } + + // Update templates atomically + c.Templates = newTemplates + c.LastFetch = time.Now() + + return nil +} + +// parseTemplate parses a YAML template from the content +func (c *GuardrailClient) parseTemplate(template MCPGuardrailTemplate) (ParsedTemplate, error) { + var parsed ParsedTemplate + + if err := yaml.Unmarshal([]byte(template.Content), &parsed); err != nil { + return parsed, fmt.Errorf("failed to parse YAML content: %w", err) + } + + // Set the ID from the template + parsed.ID = template.ID + + return parsed, nil +} + +// StartPeriodicFetching starts a goroutine that fetches templates at regular intervals +func (c *GuardrailClient) StartPeriodicFetching() { + go func() { + ticker := time.NewTicker(c.FetchInterval) + defer ticker.Stop() + + // Initial fetch + if err := c.FetchGuardrailTemplates(true); err != nil { + fmt.Printf("Initial template fetch failed: %v\n", err) + } else { + fmt.Printf("Initial templates loaded successfully\n") + } + + // Periodic fetching + for range ticker.C { + if err := c.FetchGuardrailTemplates(true); err != nil { + fmt.Printf("Template refresh failed: %v\n", err) + } else { + fmt.Printf("Templates refreshed successfully at %v\n", time.Now()) + } + } + }() +} + +// TriggerTemplateFetching manually triggers template fetching +func (c *GuardrailClient) 
TriggerTemplateFetching() error { + return c.FetchGuardrailTemplates(true) +} + +// GetTemplates returns a copy of all loaded templates +func (c *GuardrailClient) GetTemplates() map[string]ParsedTemplate { + templates := make(map[string]ParsedTemplate) + for k, v := range c.Templates { + templates[k] = v + } + return templates +} + +// GetTemplate returns a specific template by ID +func (c *GuardrailClient) GetTemplate(id string) (ParsedTemplate, bool) { + template, exists := c.Templates[id] + return template, exists +} diff --git a/libs/mcp-guardrails/go.mod b/libs/mcp-guardrails/go.mod new file mode 100644 index 0000000000..c08f933d8c --- /dev/null +++ b/libs/mcp-guardrails/go.mod @@ -0,0 +1,5 @@ +module github.com/akto-api-security/akto/libs/mcp-guardrails + +go 1.22 + +require gopkg.in/yaml.v3 v3.0.1 diff --git a/libs/mcp-guardrails/go.sum b/libs/mcp-guardrails/go.sum new file mode 100644 index 0000000000..a62c313c5b --- /dev/null +++ b/libs/mcp-guardrails/go.sum @@ -0,0 +1,4 @@ +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= +gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/libs/mcp-guardrails/guardrails.go b/libs/mcp-guardrails/guardrails.go new file mode 100644 index 0000000000..f2e8172128 --- /dev/null +++ b/libs/mcp-guardrails/guardrails.go @@ -0,0 +1,133 @@ +package guardrails + +import ( + "time" +) + +// GuardrailEngine is the main engine that coordinates template fetching and modification +type GuardrailEngine struct { + client *GuardrailClient + modifier *Modifier +} + +// NewGuardrailEngine creates a new guardrail engine +func NewGuardrailEngine(config ClientConfig) *GuardrailEngine { + client := NewGuardrailClient(config) + modifier := NewModifier(client) + + return 
&GuardrailEngine{ + client: client, + modifier: modifier, + } +} + +// Start starts the guardrail engine with periodic template fetching +func (e *GuardrailEngine) Start() { + // Start periodic template fetching + e.client.StartPeriodicFetching() +} + +// Stop stops the guardrail engine (placeholder for future implementation) +func (e *GuardrailEngine) Stop() { + // Future implementation for graceful shutdown +} + +// TriggerTemplateFetching manually triggers template fetching +func (e *GuardrailEngine) TriggerTemplateFetching() error { + return e.client.TriggerTemplateFetching() +} + +// ModifyRequest modifies a request based on active guardrail templates +func (e *GuardrailEngine) ModifyRequest(requestData string) ModificationResult { + return e.modifier.ModifyRequest(requestData) +} + +// ModifyResponse modifies a response based on active guardrail templates +func (e *GuardrailEngine) ModifyResponse(responseData string) ModificationResult { + return e.modifier.ModifyResponse(responseData) +} + +// ModifyRequestJSON modifies a JSON request +func (e *GuardrailEngine) ModifyRequestJSON(requestData []byte) (ModificationResult, error) { + return e.modifier.ModifyRequestJSON(requestData) +} + +// ModifyResponseJSON modifies a JSON response +func (e *GuardrailEngine) ModifyResponseJSON(responseData []byte) (ModificationResult, error) { + return e.modifier.ModifyResponseJSON(responseData) +} + +// GetTemplates returns all loaded templates +func (e *GuardrailEngine) GetTemplates() map[string]ParsedTemplate { + return e.client.GetTemplates() +} + +// GetTemplate returns a specific template by ID +func (e *GuardrailEngine) GetTemplate(id string) (ParsedTemplate, bool) { + return e.client.GetTemplate(id) +} + +// GetTemplateStats returns statistics about loaded templates +func (e *GuardrailEngine) GetTemplateStats() map[string]interface{} { + return e.modifier.GetTemplateStats() +} + +// SanitizeData sanitizes sensitive data in the payload +func (e *GuardrailEngine) 
SanitizeData(data string, patterns []string) string { + return e.modifier.SanitizeData(data, patterns) +} + +// CheckForSensitiveData checks if data contains sensitive information +func (e *GuardrailEngine) CheckForSensitiveData(data string, patterns []string) (bool, []string) { + return e.modifier.CheckForSensitiveData(data, patterns) +} + +// GetLastFetchTime returns the time when templates were last fetched +func (e *GuardrailEngine) GetLastFetchTime() time.Time { + return e.client.LastFetch +} + +// GetFetchInterval returns the fetch interval +func (e *GuardrailEngine) GetFetchInterval() time.Duration { + return e.client.FetchInterval +} + +// IsHealthy checks if the guardrail engine is healthy +func (e *GuardrailEngine) IsHealthy() bool { + // Check if we have templates loaded + templates := e.GetTemplates() + if len(templates) == 0 { + return false + } + + // Check if last fetch was recent (within 2x the fetch interval) + lastFetch := e.GetLastFetchTime() + if lastFetch.IsZero() { + return false + } + + interval := e.GetFetchInterval() + if time.Since(lastFetch) > interval*2 { + return false + } + + return true +} + +// GetHealthStatus returns detailed health status +func (e *GuardrailEngine) GetHealthStatus() map[string]interface{} { + templates := e.GetTemplates() + lastFetch := e.GetLastFetchTime() + interval := e.GetFetchInterval() + + status := map[string]interface{}{ + "healthy": e.IsHealthy(), + "template_count": len(templates), + "last_fetch": lastFetch, + "fetch_interval": interval, + "next_fetch_in": interval - time.Since(lastFetch), + "templates_loaded": len(templates) > 0, + } + + return status +} diff --git a/libs/mcp-guardrails/guardrails_test.go b/libs/mcp-guardrails/guardrails_test.go new file mode 100644 index 0000000000..3782784007 --- /dev/null +++ b/libs/mcp-guardrails/guardrails_test.go @@ -0,0 +1,490 @@ +package guardrails + +import ( + "io" + "net/http" + "testing" + "time" +) + +// Mock HTTP client for testing +type mockHTTPClient 
struct { + response []byte + err error +} + +func (m *mockHTTPClient) Do(req *http.Request) (*http.Response, error) { + if m.err != nil { + return nil, m.err + } + + // Create a mock response + resp := &http.Response{ + StatusCode: 200, + Body: &mockResponseBody{data: m.response}, + } + return resp, nil +} + +type mockResponseBody struct { + data []byte + pos int +} + +func (m *mockResponseBody) Read(p []byte) (n int, err error) { + if m.pos >= len(m.data) { + return 0, io.EOF + } + n = copy(p, m.data[m.pos:]) + m.pos += n + return n, nil +} + +func (m *mockResponseBody) Close() error { + return nil +} + +// Test data +var testTemplateContent = `id: PIIDataLeak +filter: + or: + - response_payload: + regex: + - (?:\\b(?:4\\d{3}|5[1-5]\\d{2}|2\\d{3}|3[47]\\d{1,2})[\\s\\-]?\\d{4,6}[\\s\\-]?\\d{4,6}?(?:[\\s\\-]\\d{3,4})?(?:\\d{3})?|\\b(?!000|666|9\\d{2})([0-8]\\d{2}|7([0-6]\\d))([-]?|\\s{1})(?!00)\\d\\d\\3(?!0000)\\d{4})\\b + - response_payload: + regex: + - \\b(\\d{4}[- ]?\\d{4}[- ]?\\d{4}|[A-Z]{5}[0-9]{4}[A-Z])\\b + +info: + name: "PIIDataLeak" + description: "PII Data Leak refers to the accidental or unauthorized exposure of Personally Identifiable Information" + details: "PII leaks commonly stem from insecure logging, improperly secured APIs" + impact: "Exposed PII can lead to identity theft, financial fraud, regulatory violations" + category: + name: "PIIDataLeak" + displayName: "PIIDataLeak" + subCategory: "PIIDataLeak" + severity: MEDIUM` + +var testAPIResponse = APIResponse{ + MCPGuardrailTemplates: []MCPGuardrailTemplate{ + { + ID: "PIIDataLeak", + Author: "system", + Content: testTemplateContent, + CreatedAt: 1759113676, + UpdatedAt: 1759113676, + Hash: 481648519, + Inactive: false, + Source: "CUSTOM", + }, + }, +} + +func TestNewGuardrailEngine(t *testing.T) { + config := ClientConfig{ + APIURL: "http://localhost:8082", + AuthToken: "testing", + FetchInterval: 10 * time.Minute, + } + + engine := NewGuardrailEngine(config) + + if engine == nil { + 
t.Fatal("Expected engine to be created") + } + + if engine.client == nil { + t.Fatal("Expected client to be initialized") + } + + if engine.modifier == nil { + t.Fatal("Expected modifier to be initialized") + } +} + +func TestGuardrailClient_FetchGuardrailTemplates(t *testing.T) { + // This test would require mocking the HTTP client + // For now, we'll test the parsing logic + client := &GuardrailClient{ + Templates: make(map[string]ParsedTemplate), + } + + // Test parsing a template + template := MCPGuardrailTemplate{ + ID: "test", + Content: testTemplateContent, + Inactive: false, + } + + parsed, err := client.parseTemplate(template) + if err != nil { + t.Fatalf("Failed to parse template: %v", err) + } + + if parsed.ID != "test" { + t.Errorf("Expected ID 'test', got '%s'", parsed.ID) + } + + if len(parsed.Filter.Or) == 0 { + t.Fatal("Expected filter conditions to be parsed") + } + + // Check if response_payload filters are parsed + foundResponsePayload := false + for _, condition := range parsed.Filter.Or { + if condition.ResponsePayload != nil && len(condition.ResponsePayload.Regex) > 0 { + foundResponsePayload = true + break + } + } + + if !foundResponsePayload { + t.Fatal("Expected response_payload filters to be found") + } +} + +func TestModifier_ModifyRequest(t *testing.T) { + // Create a mock client with test templates + client := &GuardrailClient{ + Templates: make(map[string]ParsedTemplate), + } + + // Add a test template + client.Templates["test"] = ParsedTemplate{ + ID: "test", + Filter: Filter{ + Or: []FilterCondition{ + { + RequestPayload: &PayloadFilter{ + Regex: []string{`\b\d{4}[- ]?\d{4}[- ]?\d{4}\b`}, // Credit card pattern + }, + }, + }, + }, + } + + modifier := NewModifier(client) + + // Test with sensitive data + requestData := `{"method": "test", "data": "1234-5678-9012-3456"` + result := modifier.ModifyRequest(requestData) + + if !result.Blocked { + t.Error("Expected request to be blocked due to credit card pattern") + } + + if result.Reason 
== "" { + t.Error("Expected block reason to be provided") + } + + // Test with safe data + safeRequestData := `{"method": "test", "data": "safe data"}` + safeResult := modifier.ModifyRequest(safeRequestData) + + if safeResult.Blocked { + t.Error("Expected safe request to not be blocked") + } +} + +func TestModifier_ModifyResponse(t *testing.T) { + // Create a mock client with test templates + client := &GuardrailClient{ + Templates: make(map[string]ParsedTemplate), + } + + // Add a test template + client.Templates["test"] = ParsedTemplate{ + ID: "test", + Filter: Filter{ + Or: []FilterCondition{ + { + ResponsePayload: &PayloadFilter{ + Regex: []string{`\b\d{4}[- ]?\d{4}[- ]?\d{4}\b`}, // Credit card pattern + }, + }, + }, + }, + } + + modifier := NewModifier(client) + + // Test with sensitive data + responseData := `{"status": "success", "data": "1234-5678-9012-3456"}` + result := modifier.ModifyResponse(responseData) + + if !result.Blocked { + t.Error("Expected response to be blocked due to credit card pattern") + } + + if result.Reason == "" { + t.Error("Expected block reason to be provided") + } + + // Test with safe data + safeResponseData := `{"status": "success", "data": "safe data"}` + safeResult := modifier.ModifyResponse(safeResponseData) + + if safeResult.Blocked { + t.Error("Expected safe response to not be blocked") + } +} + +func TestModifier_ModifyRequestJSON(t *testing.T) { + client := &GuardrailClient{ + Templates: make(map[string]ParsedTemplate), + } + + client.Templates["test"] = ParsedTemplate{ + ID: "test", + Filter: Filter{ + Or: []FilterCondition{ + { + RequestPayload: &PayloadFilter{ + Regex: []string{`\b\d{4}[- ]?\d{4}[- ]?\d{4}\b`}, + }, + }, + }, + }, + } + + modifier := NewModifier(client) + + // Test with valid JSON containing sensitive data + requestData := []byte(`{"method": "test", "data": "1234-5678-9012-3456"`) + result, err := modifier.ModifyRequestJSON(requestData) + + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + 
+ if !result.Blocked { + t.Error("Expected request to be blocked") + } +} + +func TestModifier_ModifyResponseJSON(t *testing.T) { + client := &GuardrailClient{ + Templates: make(map[string]ParsedTemplate), + } + + client.Templates["test"] = ParsedTemplate{ + ID: "test", + Filter: Filter{ + Or: []FilterCondition{ + { + ResponsePayload: &PayloadFilter{ + Regex: []string{`\b\d{4}[- ]?\d{4}[- ]?\d{4}\b`}, + }, + }, + }, + }, + } + + modifier := NewModifier(client) + + // Test with valid JSON containing sensitive data + responseData := []byte(`{"status": "success", "data": "1234-5678-9012-3456"}`) + result, err := modifier.ModifyResponseJSON(responseData) + + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + + if !result.Blocked { + t.Error("Expected response to be blocked") + } +} + +func TestModifier_SanitizeData(t *testing.T) { + modifier := &Modifier{} + + data := "My credit card is 1234-5678-9012-3456" + patterns := []string{`\b\d{4}[- ]?\d{4}[- ]?\d{4}\b`} + + sanitized := modifier.SanitizeData(data, patterns) + + if sanitized == data { + t.Error("Expected data to be sanitized") + } + + if !contains(sanitized, "***REDACTED***") { + t.Error("Expected sanitized data to contain redaction marker") + } +} + +func TestModifier_CheckForSensitiveData(t *testing.T) { + modifier := &Modifier{} + + data := "My credit card is 1234-5678-9012-3456" + patterns := []string{`\b\d{4}[- ]?\d{4}[- ]?\d{4}\b`} + + hasSensitive, matchedPatterns := modifier.CheckForSensitiveData(data, patterns) + + if !hasSensitive { + t.Error("Expected sensitive data to be detected") + } + + if len(matchedPatterns) == 0 { + t.Error("Expected matched patterns to be returned") + } + + // Test with safe data + safeData := "This is safe data" + hasSensitive, _ = modifier.CheckForSensitiveData(safeData, patterns) + + if hasSensitive { + t.Error("Expected safe data to not be flagged as sensitive") + } +} + +func TestGuardrailEngine_GetTemplateStats(t *testing.T) { + config := ClientConfig{ + 
APIURL: "http://localhost:8082", + AuthToken: "testing", + FetchInterval: 10 * time.Minute, + } + + engine := NewGuardrailEngine(config) + + // Add some test templates + engine.client.Templates["test1"] = ParsedTemplate{ + ID: "test1", + Info: TemplateInfo{ + Name: "Test Template 1", + Severity: "HIGH", + Category: struct { + Name string `yaml:"name"` + DisplayName string `yaml:"displayName"` + }{ + Name: "TestCategory", + DisplayName: "Test Category", + }, + }, + } + + stats := engine.GetTemplateStats() + + if stats["total_templates"] != 1 { + t.Errorf("Expected 1 template, got %v", stats["total_templates"]) + } + + templates, ok := stats["templates"].([]map[string]interface{}) + if !ok { + t.Fatal("Expected templates to be a slice") + } + + if len(templates) != 1 { + t.Errorf("Expected 1 template in stats, got %d", len(templates)) + } +} + +func TestGuardrailEngine_IsHealthy(t *testing.T) { + config := ClientConfig{ + APIURL: "http://localhost:8082", + AuthToken: "testing", + FetchInterval: 10 * time.Minute, + } + + engine := NewGuardrailEngine(config) + + // Initially should not be healthy (no templates) + if engine.IsHealthy() { + t.Error("Expected engine to not be healthy initially") + } + + // Add templates and set last fetch time + engine.client.Templates["test"] = ParsedTemplate{ID: "test"} + engine.client.LastFetch = time.Now() + + if !engine.IsHealthy() { + t.Error("Expected engine to be healthy with templates and recent fetch") + } +} + +func TestGuardrailEngine_GetHealthStatus(t *testing.T) { + config := ClientConfig{ + APIURL: "http://localhost:8082", + AuthToken: "testing", + FetchInterval: 10 * time.Minute, + } + + engine := NewGuardrailEngine(config) + + status := engine.GetHealthStatus() + + if status["healthy"] == nil { + t.Error("Expected health status to include 'healthy' field") + } + + if status["template_count"] == nil { + t.Error("Expected health status to include 'template_count' field") + } + + if status["last_fetch"] == nil { + 
t.Error("Expected health status to include 'last_fetch' field") + } +} + +// Helper function to check if string contains substring +func contains(s, substr string) bool { + return len(s) >= len(substr) && (s == substr || len(substr) == 0 || + (len(s) > len(substr) && (s[:len(substr)] == substr || + s[len(s)-len(substr):] == substr || + contains(s[1:], substr)))) +} + +// Benchmark tests +func BenchmarkModifier_ModifyRequest(b *testing.B) { + client := &GuardrailClient{ + Templates: make(map[string]ParsedTemplate), + } + + client.Templates["test"] = ParsedTemplate{ + ID: "test", + Filter: Filter{ + Or: []FilterCondition{ + { + RequestPayload: &PayloadFilter{ + Regex: []string{`\b\d{4}[- ]?\d{4}[- ]?\d{4}\b`}, + }, + }, + }, + }, + } + + modifier := NewModifier(client) + requestData := `{"method": "test", "data": "1234-5678-9012-3456"` + + b.ResetTimer() + for i := 0; i < b.N; i++ { + modifier.ModifyRequest(requestData) + } +} + +func BenchmarkModifier_ModifyResponse(b *testing.B) { + client := &GuardrailClient{ + Templates: make(map[string]ParsedTemplate), + } + + client.Templates["test"] = ParsedTemplate{ + ID: "test", + Filter: Filter{ + Or: []FilterCondition{ + { + ResponsePayload: &PayloadFilter{ + Regex: []string{`\b\d{4}[- ]?\d{4}[- ]?\d{4}\b`}, + }, + }, + }, + }, + } + + modifier := NewModifier(client) + responseData := `{"status": "success", "data": "1234-5678-9012-3456"}` + + b.ResetTimer() + for i := 0; i < b.N; i++ { + modifier.ModifyResponse(responseData) + } +} diff --git a/libs/mcp-guardrails/modifier.go b/libs/mcp-guardrails/modifier.go new file mode 100644 index 0000000000..cb7d56d614 --- /dev/null +++ b/libs/mcp-guardrails/modifier.go @@ -0,0 +1,249 @@ +package guardrails + +import ( + "encoding/json" + "fmt" + "regexp" + "sync" +) + +// Modifier handles request and response modifications based on guardrail templates +type Modifier struct { + client *GuardrailClient + mu sync.RWMutex +} + +// NewModifier creates a new modifier instance +func 
NewModifier(client *GuardrailClient) *Modifier { + return &Modifier{ + client: client, + } +} + +// ModifyRequest modifies a request based on active guardrail templates +func (m *Modifier) ModifyRequest(requestData string) ModificationResult { + m.mu.RLock() + defer m.mu.RUnlock() + + result := ModificationResult{ + Modified: false, + Blocked: false, + Data: requestData, + Warnings: []string{}, + } + + // Get all templates + templates := m.client.GetTemplates() + + for templateID, template := range templates { + // Check if template has request_payload filters + for _, condition := range template.Filter.Or { + if condition.RequestPayload != nil { + modified, blocked, reason, warnings := m.applyPayloadFilters( + requestData, + condition.RequestPayload.Regex, + "request", + templateID, + ) + + if blocked { + result.Blocked = true + result.Reason = reason + result.Warnings = append(result.Warnings, warnings...) + return result + } + + if modified { + result.Modified = true + result.Data = requestData + result.Warnings = append(result.Warnings, warnings...) + } + } + } + } + + return result +} + +// ModifyResponse modifies a response based on active guardrail templates +func (m *Modifier) ModifyResponse(responseData string) ModificationResult { + m.mu.RLock() + defer m.mu.RUnlock() + + result := ModificationResult{ + Modified: false, + Blocked: false, + Data: responseData, + Warnings: []string{}, + } + + // Get all templates + templates := m.client.GetTemplates() + + for templateID, template := range templates { + // Check if template has response_payload filters + for _, condition := range template.Filter.Or { + if condition.ResponsePayload != nil { + modified, blocked, reason, warnings := m.applyPayloadFilters( + responseData, + condition.ResponsePayload.Regex, + "response", + templateID, + ) + + if blocked { + result.Blocked = true + result.Reason = reason + result.Warnings = append(result.Warnings, warnings...) 
+ return result + } + + if modified { + result.Modified = true + result.Data = responseData + result.Warnings = append(result.Warnings, warnings...) + } + } + } + } + + return result +} + +// applyPayloadFilters applies regex filters to payload data +func (m *Modifier) applyPayloadFilters(data string, patterns []string, payloadType, templateID string) (bool, bool, string, []string) { + var warnings []string + modified := false + blocked := false + reason := "" + + for _, pattern := range patterns { + // Compile regex pattern + re, err := regexp.Compile(pattern) + if err != nil { + warnings = append(warnings, fmt.Sprintf("Invalid regex pattern in template %s: %v", templateID, err)) + continue + } + + // Check if pattern matches + if re.MatchString(data) { + // For now, we'll block on any match + // In a real implementation, you might want to sanitize instead + blocked = true + reason = fmt.Sprintf("Blocked by guardrail template %s: pattern matched in %s payload", templateID, payloadType) + warnings = append(warnings, fmt.Sprintf("Pattern matched in %s: %s", payloadType, pattern)) + break + } + } + + return modified, blocked, reason, warnings +} + +// ModifyRequestJSON modifies a JSON request based on guardrail templates +func (m *Modifier) ModifyRequestJSON(requestData []byte) (ModificationResult, error) { + var request interface{} + if err := json.Unmarshal(requestData, &request); err != nil { + return ModificationResult{ + Blocked: true, + Reason: "Invalid JSON format", + Warnings: []string{fmt.Sprintf("Failed to parse JSON: %v", err)}, + }, nil + } + + // Convert back to string for processing + requestStr := string(requestData) + result := m.ModifyRequest(requestStr) + + // If modified, convert back to JSON + if result.Modified { + // For now, we'll just return the original data + // In a real implementation, you'd apply the modifications + result.Data = requestStr + } + + return result, nil +} + +// ModifyResponseJSON modifies a JSON response based on guardrail 
templates +func (m *Modifier) ModifyResponseJSON(responseData []byte) (ModificationResult, error) { + var response interface{} + if err := json.Unmarshal(responseData, &response); err != nil { + return ModificationResult{ + Blocked: true, + Reason: "Invalid JSON format", + Warnings: []string{fmt.Sprintf("Failed to parse JSON: %v", err)}, + }, nil + } + + // Convert back to string for processing + responseStr := string(responseData) + result := m.ModifyResponse(responseStr) + + // If modified, convert back to JSON + if result.Modified { + // For now, we'll just return the original data + // In a real implementation, you'd apply the modifications + result.Data = responseStr + } + + return result, nil +} + +// SanitizeData sanitizes sensitive data in the payload +func (m *Modifier) SanitizeData(data string, patterns []string) string { + sanitized := data + + for _, pattern := range patterns { + re, err := regexp.Compile(pattern) + if err != nil { + continue + } + + // Replace matches with sanitized text + sanitized = re.ReplaceAllString(sanitized, "***REDACTED***") + } + + return sanitized +} + +// CheckForSensitiveData checks if data contains sensitive information +func (m *Modifier) CheckForSensitiveData(data string, patterns []string) (bool, []string) { + var matchedPatterns []string + + for _, pattern := range patterns { + re, err := regexp.Compile(pattern) + if err != nil { + continue + } + + if re.MatchString(data) { + matchedPatterns = append(matchedPatterns, pattern) + } + } + + return len(matchedPatterns) > 0, matchedPatterns +} + +// GetTemplateStats returns statistics about loaded templates +func (m *Modifier) GetTemplateStats() map[string]interface{} { + templates := m.client.GetTemplates() + + stats := map[string]interface{}{ + "total_templates": len(templates), + "templates": make([]map[string]interface{}, 0), + } + + templateList := make([]map[string]interface{}, 0) + for id, template := range templates { + templateInfo := map[string]interface{}{ + 
"id": id, + "name": template.Info.Name, + "severity": template.Info.Severity, + "category": template.Info.Category.Name, + } + templateList = append(templateList, templateInfo) + } + + stats["templates"] = templateList + return stats +} diff --git a/libs/mcp-guardrails/sample-template.yaml b/libs/mcp-guardrails/sample-template.yaml new file mode 100644 index 0000000000..6a2cb5a4fe --- /dev/null +++ b/libs/mcp-guardrails/sample-template.yaml @@ -0,0 +1,64 @@ +# Sample MCP Guardrail YAML Template +id: "data_sanitization_basic" +type: "DATA_SANITIZATION" +enabled: true +priority: 100 + +configuration: + sensitiveFields: + - "password" + - "api_key" + - "secret" + - "token" + - "private_key" + - "access_token" + - "refresh_token" + + patterns: + - name: "Credit Card" + pattern: "\\b(?:\\d[ -]*?){13,16}\\b" + replacement: "***CREDIT_CARD***" + + - name: "Social Security Number" + pattern: "\\b\\d{3}-?\\d{2}-?\\d{4}\\b" + replacement: "***SSN***" + + - name: "Email Address" + pattern: "\\b[A-Za-z0-9._%+-]+@[A-Za-z0-9.-]+\\.[A-Z|a-z]{2,}\\b" + replacement: "***EMAIL***" + + - name: "Phone Number" + pattern: "\\b\\(?\\d{3}\\)?[-.\\s]?\\d{3}[-.\\s]?\\d{4}\\b" + replacement: "***PHONE***" + + validationRules: + method: "required" + + outputFilters: + - "block_sensitive" + - "warn_on_pii" + +info: + name: "Basic Data Sanitization" + description: "Removes sensitive data patterns from MCP responses" + details: "This guardrail template provides basic data sanitization by detecting and redacting common sensitive data patterns such as credit cards, SSNs, emails, and phone numbers." 
+ purpose: "Protect sensitive information from being exposed in MCP responses" + category: + name: "data_protection" + displayName: "Data Protection" + shortName: "DP" + description: "Templates focused on protecting sensitive data" + severity: "HIGH" + tags: + - "pii" + - "data_protection" + - "sanitization" + - "compliance" + applicableScenarios: + - "customer_data_handling" + - "financial_services" + - "healthcare" + - "general_compliance" + dependencies: [] + enabled: true + priority: 100 diff --git a/libs/mcp-guardrails/types.go b/libs/mcp-guardrails/types.go new file mode 100644 index 0000000000..9ea5f4458e --- /dev/null +++ b/libs/mcp-guardrails/types.go @@ -0,0 +1,86 @@ +package guardrails + +import ( + "time" +) + +// MCPGuardrailTemplate represents a guardrail template from the API +type MCPGuardrailTemplate struct { + ID string `json:"id"` + Author string `json:"author"` + Content string `json:"content"` + CreatedAt int64 `json:"createdAt"` + UpdatedAt int64 `json:"updatedAt"` + Hash int64 `json:"hash"` + Inactive bool `json:"inactive"` + Source string `json:"source"` + RepositoryURL *string `json:"repositoryUrl,omitempty"` + Info *TemplateInfo `json:"info,omitempty"` +} + +// TemplateInfo contains metadata about the template +type TemplateInfo struct { + Name string `yaml:"name"` + Description string `yaml:"description"` + Details string `yaml:"details"` + Impact string `yaml:"impact"` + Category struct { + Name string `yaml:"name"` + DisplayName string `yaml:"displayName"` + } `yaml:"category"` + SubCategory string `yaml:"subCategory"` + Severity string `yaml:"severity"` +} + +// ParsedTemplate represents a parsed YAML template +type ParsedTemplate struct { + ID string `yaml:"id"` + Filter Filter `yaml:"filter"` + Info TemplateInfo `yaml:"info"` +} + +// Filter contains the regex patterns for request and response payloads +type Filter struct { + Or []FilterCondition `yaml:"or"` +} + +// FilterCondition represents a single filter condition +type 
FilterCondition struct { + RequestPayload *PayloadFilter `yaml:"request_payload,omitempty"` + ResponsePayload *PayloadFilter `yaml:"response_payload,omitempty"` +} + +// PayloadFilter contains regex patterns for payload filtering +type PayloadFilter struct { + Regex []string `yaml:"regex"` +} + +// APIResponse represents the response from the fetchGuardrailTemplates API +type APIResponse struct { + MCPGuardrailTemplates []MCPGuardrailTemplate `json:"mcpGuardrailTemplates"` +} + +// GuardrailClient handles fetching and managing guardrail templates +type GuardrailClient struct { + APIURL string + AuthToken string + Templates map[string]ParsedTemplate + LastFetch time.Time + FetchInterval time.Duration +} + +// ModificationResult represents the result of request/response modification +type ModificationResult struct { + Modified bool `json:"modified"` + Blocked bool `json:"blocked"` + Reason string `json:"reason,omitempty"` + Warnings []string `json:"warnings,omitempty"` + Data string `json:"data"` +} + +// ClientConfig holds configuration for the guardrail client +type ClientConfig struct { + APIURL string + AuthToken string + FetchInterval time.Duration +} diff --git a/libs/utils/src/main/java/com/akto/utils/MCPGuardrailUtil.java b/libs/utils/src/main/java/com/akto/utils/MCPGuardrailUtil.java new file mode 100644 index 0000000000..8d40c36842 --- /dev/null +++ b/libs/utils/src/main/java/com/akto/utils/MCPGuardrailUtil.java @@ -0,0 +1,36 @@ +package com.akto.utils; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; + +import org.bson.conversions.Bson; + +import com.akto.dao.context.Context; +import com.akto.dto.test_editor.YamlTemplate; +import com.akto.util.enums.GlobalEnums; +import com.mongodb.client.model.Updates; + +public class MCPGuardrailUtil { + + public static List getDbUpdateForTemplate(String content, String userEmail) throws Exception { + try { + String author = userEmail; + int timeNow = Context.now(); + int createdAt = 
timeNow; + int updatedAt = timeNow; + + List updates = new ArrayList<>( + Arrays.asList( + Updates.setOnInsert(YamlTemplate.CREATED_AT, createdAt), + Updates.setOnInsert(YamlTemplate.AUTHOR, author), + Updates.set(YamlTemplate.UPDATED_AT, updatedAt), + Updates.set(YamlTemplate.CONTENT, content), + Updates.setOnInsert(YamlTemplate.SOURCE, GlobalEnums.YamlTemplateSource.CUSTOM))); + return updates; + + } catch (Exception e) { + throw e; + } + } +}