Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
144 changes: 112 additions & 32 deletions internal/mcp/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -189,8 +189,8 @@ type RPCError struct {

// ListConventionInput represents the input schema for the list_convention tool (go-sdk).
type ListConventionInput struct {
Category string `json:"category,omitempty" jsonschema:"Filter by category (optional). Use 'all' or leave empty to fetch all categories. Options: security, style, documentation, error_handling, architecture, performance, testing"`
Languages []string `json:"languages,omitempty" jsonschema:"Programming languages to filter by (optional). Leave empty to get conventions for all languages. Examples: go, javascript, typescript, python, java"`
Categories []string `json:"categories,omitempty" jsonschema:"Filter by categories (optional). Leave empty or use [\"all\"] to fetch all categories. Example: [\"security\", \"style\"]"`
Languages []string `json:"languages,omitempty" jsonschema:"Programming languages to filter by (optional). Leave empty to get conventions for all languages. Examples: go, javascript, typescript, python, java"`
}

// ValidateCodeInput represents the input schema for the validate_code tool (go-sdk).
Expand Down Expand Up @@ -303,11 +303,11 @@ func (s *Server) runStdioWithSDK(ctx context.Context) error {
// Tool: list_convention
sdkmcp.AddTool(server, &sdkmcp.Tool{
Name: "list_convention",
Description: "[MANDATORY BEFORE CODING] List project conventions BEFORE writing any code to ensure compliance from the start. Filter by category or languages.",
Description: "[MANDATORY BEFORE CODING] List project conventions BEFORE writing any code to ensure compliance from the start. Filter by categories or languages.",
}, func(ctx context.Context, req *sdkmcp.CallToolRequest, input ListConventionInput) (*sdkmcp.CallToolResult, map[string]any, error) {
params := map[string]any{
"category": input.Category,
"languages": input.Languages,
"categories": input.Categories,
"languages": input.Languages,
}
result, rpcErr := s.handleListConvention(params)
if rpcErr != nil {
Expand Down Expand Up @@ -446,8 +446,8 @@ func (s *Server) runStdioWithSDK(ctx context.Context) error {

// QueryConventionsRequest is a request to query conventions.
type QueryConventionsRequest struct {
Category string `json:"category"` // optional; use "all" or empty to fetch all categories
Languages []string `json:"languages"` // optional; empty means all languages
Categories []string `json:"categories"` // optional; use ["all"] or empty to fetch all categories
Languages []string `json:"languages"` // optional; empty means all languages
}

// ConventionItem is a convention item.
Expand Down Expand Up @@ -479,41 +479,23 @@ func (s *Server) handleListConvention(params map[string]interface{}) (interface{
}

// Apply defaults for missing parameters
// If category is empty or "all", return all categories
if strings.TrimSpace(req.Category) == "" || strings.EqualFold(req.Category, "all") {
req.Category = ""
}
// If categories contains "all" or is empty, return all categories
req.Categories = normalizeCategories(req.Categories)

// If languages is empty, return all languages
// This is more user-friendly than requiring the parameter

conventions := s.filterConventions(req)

// Format conventions as readable text for MCP response
var textContent string
if len(conventions) == 0 {
textContent = "No conventions found for the specified criteria."
} else {
textContent = fmt.Sprintf("Found %d convention(s):\n\n", len(conventions))
for i, conv := range conventions {
textContent += fmt.Sprintf("%d. [%s] %s\n", i+1, conv.Severity, conv.ID)
textContent += fmt.Sprintf(" Category: %s\n", conv.Category)
textContent += fmt.Sprintf(" Description: %s\n", conv.Description)
if conv.Message != "" && conv.Message != conv.Description {
textContent += fmt.Sprintf(" Message: %s\n", conv.Message)
}
textContent += "\n"
}
}
// Build structured prompt-style response
textContent := s.buildConventionPrompt(conventions, req)

// Add RBAC information if available
rbacInfo := s.getRBACInfo()
if rbacInfo != "" {
textContent += "\n\n" + rbacInfo
textContent += "\n" + rbacInfo + "\n"
}

textContent += "\n✓ Next Step: Implement your code following these conventions. After completion, MUST call validate_code to verify compliance."

// Return MCP-compliant response with content array
return map[string]interface{}{
"content": []map[string]interface{}{
Expand All @@ -532,7 +514,8 @@ func (s *Server) filterConventions(req QueryConventionsRequest) []ConventionItem
// If UserPolicy is loaded, use natural language rules
if s.userPolicy != nil {
for _, rule := range s.userPolicy.Rules {
if req.Category != "" && rule.Category != req.Category {
// Check category filter
if !matchesCategories(rule.Category, req.Categories) {
continue
}

Expand Down Expand Up @@ -576,7 +559,8 @@ func (s *Server) filterConventions(req QueryConventionsRequest) []ConventionItem
continue
}

if req.Category != "" && rule.Category != req.Category {
// Check category filter (supports multiple categories)
if !matchesCategories(rule.Category, req.Categories) {
continue
}

Expand Down Expand Up @@ -792,6 +776,46 @@ func containsAny(haystack, needles []string) bool {
return false
}

// normalizeCategories normalizes the categories input.
// Returns nil if categories is empty or contains "all" (meaning fetch all).
func normalizeCategories(categories []string) []string {
if len(categories) == 0 {
return nil
}
// Check if any category is "all"
for _, cat := range categories {
if strings.EqualFold(strings.TrimSpace(cat), "all") {
return nil
}
}
// Trim whitespace from each category
result := make([]string, 0, len(categories))
for _, cat := range categories {
trimmed := strings.TrimSpace(cat)
if trimmed != "" {
result = append(result, trimmed)
}
}
if len(result) == 0 {
return nil
}
return result
}

// matchesCategories checks if the given category matches any of the requested categories.
// Returns true if categories is nil/empty (meaning all categories match).
func matchesCategories(ruleCategory string, requestedCategories []string) bool {
if len(requestedCategories) == 0 {
return true
}
for _, cat := range requestedCategories {
if ruleCategory == cat {
return true
}
}
return false
}

// getValidationPolicy returns CodePolicy for validation.
func (s *Server) getValidationPolicy() (*schema.CodePolicy, error) {
if s.codePolicy != nil {
Expand Down Expand Up @@ -1785,3 +1809,59 @@ func (s *Server) buildConventionBatchResponse(action string, succeeded []string,
},
}
}

// buildConventionPrompt builds a structured prompt-style response for AI coding assistants.
func (s *Server) buildConventionPrompt(conventions []ConventionItem, req QueryConventionsRequest) string {
var sb strings.Builder

// Header
sb.WriteString("# Project Coding Conventions\n\n")

if len(conventions) == 0 {
sb.WriteString("No conventions found for the specified criteria.\n\n")
sb.WriteString("Use `list_category` to see available categories.\n")
return sb.String()
}

// Summary with filter context
sb.WriteString(fmt.Sprintf("Total: %d convention(s)\n", len(conventions)))
if len(req.Categories) > 0 {
sb.WriteString(fmt.Sprintf("Filtered by: %s\n", strings.Join(req.Categories, ", ")))
}
if len(req.Languages) > 0 {
sb.WriteString(fmt.Sprintf("Languages: %s\n", strings.Join(req.Languages, ", ")))
}
sb.WriteString("\n---\n\n")

// Group by category
categoryMap := make(map[string][]ConventionItem)
categoryOrder := []string{}
for _, conv := range conventions {
cat := conv.Category
if cat == "" {
cat = "General"
}
if _, exists := categoryMap[cat]; !exists {
categoryOrder = append(categoryOrder, cat)
}
categoryMap[cat] = append(categoryMap[cat], conv)
}

// Render each category
for _, cat := range categoryOrder {
sb.WriteString(fmt.Sprintf("## %s\n\n", cat))
for _, conv := range categoryMap[cat] {
sb.WriteString(fmt.Sprintf("**%s** `[%s]`\n", conv.ID, conv.Severity))
sb.WriteString(fmt.Sprintf(" %s\n\n", conv.Description))
}
}

// Instructions section
sb.WriteString("---\n\n")
sb.WriteString("## Instructions\n\n")
sb.WriteString("1. Apply all conventions above when writing code(without configuration files)\n")
Copy link

Copilot AI Dec 14, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Missing space before the opening parenthesis in "code(without configuration files)". Should be "code (without configuration files)".

Suggested change
sb.WriteString("1. Apply all conventions above when writing code(without configuration files)\n")
sb.WriteString("1. Apply all conventions above when writing code (without configuration files)\n")

Copilot uses AI. Check for mistakes.
sb.WriteString("2. Prioritize error-level rules (these block validation)\n")
sb.WriteString("3. After implementation, call `validate_code` to verify compliance\n")

return sb.String()
}
22 changes: 11 additions & 11 deletions internal/mcp/server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -73,8 +73,8 @@ func TestQueryConventions(t *testing.T) {

t.Run("query all categories for javascript", func(t *testing.T) {
params := map[string]interface{}{
"category": "all",
"languages": []interface{}{"javascript"},
"categories": []interface{}{"all"},
"languages": []interface{}{"javascript"},
}

result, rpcErr := server.handleListConvention(params)
Expand All @@ -96,8 +96,8 @@ func TestQueryConventions(t *testing.T) {

t.Run("query documentation category for javascript", func(t *testing.T) {
params := map[string]interface{}{
"category": "documentation",
"languages": []interface{}{"javascript"},
"categories": []interface{}{"documentation"},
"languages": []interface{}{"javascript"},
}

result, rpcErr := server.handleListConvention(params)
Expand All @@ -118,8 +118,8 @@ func TestQueryConventions(t *testing.T) {

t.Run("query security category for typescript", func(t *testing.T) {
params := map[string]interface{}{
"category": "security",
"languages": []interface{}{"typescript"},
"categories": []interface{}{"security"},
"languages": []interface{}{"typescript"},
}

result, rpcErr := server.handleListConvention(params)
Expand All @@ -139,8 +139,8 @@ func TestQueryConventions(t *testing.T) {

t.Run("query with unsupported language", func(t *testing.T) {
params := map[string]interface{}{
"category": "all",
"languages": []interface{}{"python"},
"categories": []interface{}{"all"},
"languages": []interface{}{"python"},
}

result, rpcErr := server.handleListConvention(params)
Expand All @@ -159,8 +159,8 @@ func TestQueryConventions(t *testing.T) {

t.Run("rule without severity uses defaults", func(t *testing.T) {
params := map[string]interface{}{
"category": "style",
"languages": []interface{}{"javascript"},
"categories": []interface{}{"style"},
"languages": []interface{}{"javascript"},
}

result, rpcErr := server.handleListConvention(params)
Expand Down Expand Up @@ -199,7 +199,7 @@ func TestQueryConventions(t *testing.T) {

t.Run("only category specified", func(t *testing.T) {
params := map[string]interface{}{
"category": "security",
"categories": []interface{}{"security"},
}

result, rpcErr := server.handleListConvention(params)
Expand Down