diff --git a/internal/mcp/server.go b/internal/mcp/server.go index e1a0d77..288a096 100644 --- a/internal/mcp/server.go +++ b/internal/mcp/server.go @@ -189,8 +189,8 @@ type RPCError struct { // ListConventionInput represents the input schema for the list_convention tool (go-sdk). type ListConventionInput struct { - Category string `json:"category,omitempty" jsonschema:"Filter by category (optional). Use 'all' or leave empty to fetch all categories. Options: security, style, documentation, error_handling, architecture, performance, testing"` - Languages []string `json:"languages,omitempty" jsonschema:"Programming languages to filter by (optional). Leave empty to get conventions for all languages. Examples: go, javascript, typescript, python, java"` + Categories []string `json:"categories,omitempty" jsonschema:"Filter by categories (optional). Leave empty or use [\"all\"] to fetch all categories. Example: [\"security\", \"style\"]"` + Languages []string `json:"languages,omitempty" jsonschema:"Programming languages to filter by (optional). Leave empty to get conventions for all languages. Examples: go, javascript, typescript, python, java"` } // ValidateCodeInput represents the input schema for the validate_code tool (go-sdk). @@ -303,11 +303,11 @@ func (s *Server) runStdioWithSDK(ctx context.Context) error { // Tool: list_convention sdkmcp.AddTool(server, &sdkmcp.Tool{ Name: "list_convention", - Description: "[MANDATORY BEFORE CODING] List project conventions BEFORE writing any code to ensure compliance from the start. Filter by category or languages.", + Description: "[MANDATORY BEFORE CODING] List project conventions BEFORE writing any code to ensure compliance from the start. Filter by categories or languages.", }, func(ctx context.Context, req *sdkmcp.CallToolRequest, input ListConventionInput) (*sdkmcp.CallToolResult, map[string]any, error) { params := map[string]any{ - "category": input.Category, - "languages": input.Languages, + "categories": input.Categories, + "languages": input.Languages, } result, rpcErr := s.handleListConvention(params) if rpcErr != nil { @@ -446,8 +446,8 @@ func (s *Server) runStdioWithSDK(ctx context.Context) error { // QueryConventionsRequest is a request to query conventions. type QueryConventionsRequest struct { - Category string `json:"category"` // optional; use "all" or empty to fetch all categories - Languages []string `json:"languages"` // optional; empty means all languages + Categories []string `json:"categories"` // optional; use ["all"] or empty to fetch all categories + Languages []string `json:"languages"` // optional; empty means all languages } // ConventionItem is a convention item. @@ -479,41 +479,23 @@ func (s *Server) handleListConvention(params map[string]interface{}) (interface{ } // Apply defaults for missing parameters - // If category is empty or "all", return all categories - if strings.TrimSpace(req.Category) == "" || strings.EqualFold(req.Category, "all") { - req.Category = "" - } + // If categories contains "all" or is empty, return all categories + req.Categories = normalizeCategories(req.Categories) // If languages is empty, return all languages // This is more user-friendly than requiring the parameter conventions := s.filterConventions(req) - // Format conventions as readable text for MCP response - var textContent string - if len(conventions) == 0 { - textContent = "No conventions found for the specified criteria." - } else { - textContent = fmt.Sprintf("Found %d convention(s):\n\n", len(conventions)) - for i, conv := range conventions { - textContent += fmt.Sprintf("%d. [%s] %s\n", i+1, conv.Severity, conv.ID) - textContent += fmt.Sprintf(" Category: %s\n", conv.Category) - textContent += fmt.Sprintf(" Description: %s\n", conv.Description) - if conv.Message != "" && conv.Message != conv.Description { - textContent += fmt.Sprintf(" Message: %s\n", conv.Message) - } - textContent += "\n" - } - } + // Build structured prompt-style response + textContent := s.buildConventionPrompt(conventions, req) // Add RBAC information if available rbacInfo := s.getRBACInfo() if rbacInfo != "" { - textContent += "\n\n" + rbacInfo + textContent += "\n" + rbacInfo + "\n" } - textContent += "\n✓ Next Step: Implement your code following these conventions. After completion, MUST call validate_code to verify compliance." - // Return MCP-compliant response with content array return map[string]interface{}{ "content": []map[string]interface{}{ @@ -532,7 +514,8 @@ func (s *Server) filterConventions(req QueryConventionsRequest) []ConventionItem // If UserPolicy is loaded, use natural language rules if s.userPolicy != nil { for _, rule := range s.userPolicy.Rules { - if req.Category != "" && rule.Category != req.Category { + // Check category filter + if !matchesCategories(rule.Category, req.Categories) { continue } @@ -576,7 +559,8 @@ func (s *Server) filterConventions(req QueryConventionsRequest) []ConventionItem continue } - if req.Category != "" && rule.Category != req.Category { + // Check category filter (supports multiple categories) + if !matchesCategories(rule.Category, req.Categories) { continue } @@ -792,6 +776,46 @@ func containsAny(haystack, needles []string) bool { return false } +// normalizeCategories normalizes the categories input. +// Returns nil if categories is empty or contains "all" (meaning fetch all). +func normalizeCategories(categories []string) []string { + if len(categories) == 0 { + return nil + } + // Check if any category is "all" + for _, cat := range categories { + if strings.EqualFold(strings.TrimSpace(cat), "all") { + return nil + } + } + // Trim whitespace from each category + result := make([]string, 0, len(categories)) + for _, cat := range categories { + trimmed := strings.TrimSpace(cat) + if trimmed != "" { + result = append(result, trimmed) + } + } + if len(result) == 0 { + return nil + } + return result +} + +// matchesCategories checks if the given category matches any of the requested categories. +// Returns true if categories is nil/empty (meaning all categories match). +func matchesCategories(ruleCategory string, requestedCategories []string) bool { + if len(requestedCategories) == 0 { + return true + } + for _, cat := range requestedCategories { + if ruleCategory == cat { + return true + } + } + return false +} + // getValidationPolicy returns CodePolicy for validation. func (s *Server) getValidationPolicy() (*schema.CodePolicy, error) { if s.codePolicy != nil { @@ -1785,3 +1809,59 @@ func (s *Server) buildConventionBatchResponse(action string, succeeded []string, }, } } + +// buildConventionPrompt builds a structured prompt-style response for AI coding assistants. +func (s *Server) buildConventionPrompt(conventions []ConventionItem, req QueryConventionsRequest) string { + var sb strings.Builder + + // Header + sb.WriteString("# Project Coding Conventions\n\n") + + if len(conventions) == 0 { + sb.WriteString("No conventions found for the specified criteria.\n\n") + sb.WriteString("Use `list_category` to see available categories.\n") + return sb.String() + } + + // Summary with filter context + sb.WriteString(fmt.Sprintf("Total: %d convention(s)\n", len(conventions))) + if len(req.Categories) > 0 { + sb.WriteString(fmt.Sprintf("Filtered by: %s\n", strings.Join(req.Categories, ", "))) + } + if len(req.Languages) > 0 { + sb.WriteString(fmt.Sprintf("Languages: %s\n", strings.Join(req.Languages, ", "))) + } + sb.WriteString("\n---\n\n") + + // Group by category + categoryMap := make(map[string][]ConventionItem) + categoryOrder := []string{} + for _, conv := range conventions { + cat := conv.Category + if cat == "" { + cat = "General" + } + if _, exists := categoryMap[cat]; !exists { + categoryOrder = append(categoryOrder, cat) + } + categoryMap[cat] = append(categoryMap[cat], conv) + } + + // Render each category + for _, cat := range categoryOrder { + sb.WriteString(fmt.Sprintf("## %s\n\n", cat)) + for _, conv := range categoryMap[cat] { + sb.WriteString(fmt.Sprintf("**%s** `[%s]`\n", conv.ID, conv.Severity)) + sb.WriteString(fmt.Sprintf(" %s\n\n", conv.Description)) + } + } + + // Instructions section + sb.WriteString("---\n\n") + sb.WriteString("## Instructions\n\n") + sb.WriteString("1. Apply all conventions above when writing code(without configuration files)\n") + sb.WriteString("2. Prioritize error-level rules (these block validation)\n") + sb.WriteString("3. After implementation, call `validate_code` to verify compliance\n") + + return sb.String() +} diff --git a/internal/mcp/server_test.go b/internal/mcp/server_test.go index f43135c..bab30c5 100644 --- a/internal/mcp/server_test.go +++ b/internal/mcp/server_test.go @@ -73,8 +73,8 @@ func TestQueryConventions(t *testing.T) { t.Run("query all categories for javascript", func(t *testing.T) { params := map[string]interface{}{ - "category": "all", - "languages": []interface{}{"javascript"}, + "categories": []interface{}{"all"}, + "languages": []interface{}{"javascript"}, } result, rpcErr := server.handleListConvention(params) @@ -96,8 +96,8 @@ func TestQueryConventions(t *testing.T) { t.Run("query documentation category for javascript", func(t *testing.T) { params := map[string]interface{}{ - "category": "documentation", - "languages": []interface{}{"javascript"}, + "categories": []interface{}{"documentation"}, + "languages": []interface{}{"javascript"}, } result, rpcErr := server.handleListConvention(params) @@ -118,8 +118,8 @@ func TestQueryConventions(t *testing.T) { t.Run("query security category for typescript", func(t *testing.T) { params := map[string]interface{}{ - "category": "security", - "languages": []interface{}{"typescript"}, + "categories": []interface{}{"security"}, + "languages": []interface{}{"typescript"}, } result, rpcErr := server.handleListConvention(params) @@ -139,8 +139,8 @@ func TestQueryConventions(t *testing.T) { t.Run("query with unsupported language", func(t *testing.T) { params := map[string]interface{}{ - "category": "all", - "languages": []interface{}{"python"}, + "categories": []interface{}{"all"}, + "languages": []interface{}{"python"}, } result, rpcErr := server.handleListConvention(params) @@ -159,8 +159,8 @@ func TestQueryConventions(t *testing.T) { t.Run("rule without severity uses defaults", func(t *testing.T) { params := map[string]interface{}{ - "category": "style", - "languages": []interface{}{"javascript"}, + "categories": []interface{}{"style"}, + "languages": []interface{}{"javascript"}, } result, rpcErr := server.handleListConvention(params) @@ -199,7 +199,7 @@ func TestQueryConventions(t *testing.T) { t.Run("only category specified", func(t *testing.T) { params := map[string]interface{}{ - "category": "security", + "categories": []interface{}{"security"}, } result, rpcErr := server.handleListConvention(params)