Skip to content

Commit 59661df

Browse files
committed
refactor(tool): use ToolType enum for tool name checks #453
Replaces string literals with ToolType enum references for tool name comparisons and pattern matching across the codebase, improving type safety and consistency. Removes legacy fallback logic for unknown tools.
1 parent 1c53270 commit 59661df

File tree

8 files changed

+21
-59
lines changed

8 files changed

+21
-59
lines changed

mpp-core/src/commonMain/kotlin/cc/unitmesh/agent/CodingAgent.kt

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ import cc.unitmesh.agent.orchestrator.ToolOrchestrator
99
import cc.unitmesh.agent.policy.DefaultPolicyEngine
1010
import cc.unitmesh.agent.render.CodingAgentRenderer
1111
import cc.unitmesh.agent.render.DefaultCodingAgentRenderer
12+
import cc.unitmesh.agent.subagent.CodebaseInvestigatorAgent
1213
import cc.unitmesh.agent.subagent.ErrorRecoveryAgent
1314
import cc.unitmesh.agent.subagent.LogSummaryAgent
1415
import cc.unitmesh.agent.tool.ToolResult
@@ -53,14 +54,14 @@ class CodingAgent(
5354
shellExecutor = shellExecutor ?: DefaultShellExecutor()
5455
)
5556

56-
// New orchestration components
5757
private val policyEngine = DefaultPolicyEngine()
5858
private val toolOrchestrator = ToolOrchestrator(toolRegistry, policyEngine, renderer)
5959

60-
// SubAgents
6160
private val errorRecoveryAgent = ErrorRecoveryAgent(projectPath, llmService)
6261
private val logSummaryAgent = LogSummaryAgent(llmService, threshold = 2000)
6362

63+
private val codebaseInvestigatorAgent = CodebaseInvestigatorAgent(projectPath, llmService)
64+
6465
// 执行器
6566
private val executor = CodingAgentExecutor(
6667
projectPath = projectPath,
@@ -74,25 +75,22 @@ class CodingAgent(
7475
// 注册 SubAgents(作为 Tools)
7576
registerTool(errorRecoveryAgent)
7677
registerTool(logSummaryAgent)
78+
registerTool(codebaseInvestigatorAgent)
7779

78-
// ToolRegistry 已经在 init 中注册了内置 tools(read-file, write-file, shell, glob)
80+
/// TODO 注册 MCP Tools
7981
}
8082

8183
override suspend fun execute(
8284
input: AgentTask,
8385
onProgress: (String) -> Unit
8486
): ToolResult.AgentResult {
85-
// 初始化工作空间
8687
initializeWorkspace(input.projectPath)
8788

88-
// 构建系统提示词
8989
val context = buildContext(input)
9090
val systemPrompt = buildSystemPrompt(context)
9191

92-
// 使用执行器执行任务
9392
val result = executor.execute(input, systemPrompt, onProgress)
9493

95-
// 返回结果
9694
return ToolResult.AgentResult(
9795
success = result.success,
9896
content = result.message,

mpp-core/src/commonMain/kotlin/cc/unitmesh/agent/executor/CodingAgentExecutor.kt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -152,7 +152,7 @@ class CodingAgentExecutor(
152152
ToolType.ReadFile, ToolType.WriteFile -> 3
153153
ToolType.Shell -> 2
154154
else -> when (toolName) {
155-
"read-file", "write-file" -> 3
155+
ToolType.ReadFile.name, ToolType.WriteFile.name -> 3
156156
"shell" -> 2
157157
else -> 2
158158
}

mpp-core/src/commonMain/kotlin/cc/unitmesh/agent/orchestrator/ToolOrchestrator.kt

Lines changed: 0 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -172,17 +172,6 @@ class ToolOrchestrator(
172172
ToolType.WriteFile -> executeWriteFileTool(tool, params, basicContext)
173173
ToolType.Glob -> executeGlobTool(tool, params, basicContext)
174174
ToolType.Grep -> executeGrepTool(tool, params, basicContext)
175-
null -> {
176-
// Fallback for unknown tools or legacy string matching
177-
when (toolName) {
178-
"shell" -> executeShellTool(tool, params, basicContext)
179-
"read-file" -> executeReadFileTool(tool, params, basicContext)
180-
"write-file" -> executeWriteFileTool(tool, params, basicContext)
181-
"glob" -> executeGlobTool(tool, params, basicContext)
182-
"grep" -> executeGrepTool(tool, params, basicContext)
183-
else -> ToolResult.Error("Unknown tool: $toolName")
184-
}
185-
}
186175
else -> ToolResult.Error("Tool not implemented: ${toolType.displayName}")
187176
}
188177
}

mpp-core/src/commonMain/kotlin/cc/unitmesh/agent/parser/ToolCallParser.kt

Lines changed: 6 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -26,8 +26,7 @@ class ToolCallParser {
2626
if (firstBlock != null) {
2727
val toolCall = parseToolCallFromDevinBlock(firstBlock)
2828
if (toolCall != null) {
29-
// For write-file tools, try to extract content from the surrounding context
30-
if (toolCall.toolName == "write-file" && !toolCall.params.containsKey("content")) {
29+
if (toolCall.toolName == ToolType.WriteFile.name && !toolCall.params.containsKey("content")) {
3130
val contentFromContext = extractContentFromContext(llmResponse, firstBlock)
3231
if (contentFromContext != null) {
3332
val updatedParams = toolCall.params.toMutableMap()
@@ -46,10 +45,6 @@ class ToolCallParser {
4645
return toolCalls
4746
}
4847

49-
fun parseDevinBlocks(content: String): List<DevinBlock> {
50-
return devinParser.extractDevinBlocks(content)
51-
}
52-
5348
private fun parseToolCallFromDevinBlock(block: DevinBlock): ToolCall? {
5449
val lines = block.content.lines()
5550

@@ -65,10 +60,7 @@ class ToolCallParser {
6560

6661
return null
6762
}
68-
69-
/**
70-
* Parse a direct tool call (without DevIn blocks)
71-
*/
63+
7264
private fun parseDirectToolCall(response: String): ToolCall? {
7365
val toolPattern = Regex("""/(\w+(?:-\w+)*)(.*)""", RegexOption.MULTILINE)
7466
val match = toolPattern.find(response) ?: return null
@@ -79,9 +71,6 @@ class ToolCallParser {
7971
return parseToolCallFromLine("/$toolName $rest")
8072
}
8173

82-
/**
83-
* Parse a tool call from a single line
84-
*/
8574
private fun parseToolCallFromLine(line: String): ToolCall? {
8675
val toolPattern = Regex("""/(\w+(?:-\w+)*)(.*)""")
8776
val match = toolPattern.find(line) ?: return null
@@ -93,10 +82,7 @@ class ToolCallParser {
9382

9483
return ToolCall.create(toolName, params)
9584
}
96-
97-
/**
98-
* Parse parameters from the rest of the tool call line
99-
*/
85+
10086
private fun parseParameters(toolName: String, rest: String): Map<String, Any> {
10187
val params = mutableMapOf<String, Any>()
10288

@@ -108,10 +94,7 @@ class ToolCallParser {
10894

10995
return params
11096
}
111-
112-
/**
113-
* Parse key="value" style parameters
114-
*/
97+
11598
private fun parseKeyValueParameters(rest: String, params: MutableMap<String, Any>) {
11699
val remaining = rest.toCharArray().toList()
117100
var i = 0
@@ -157,11 +140,9 @@ class ToolCallParser {
157140
* Parse simple parameter (single value without key)
158141
*/
159142
private fun parseSimpleParameter(toolName: String, rest: String, params: MutableMap<String, Any>) {
160-
if (toolName == "shell") {
143+
if (toolName == ToolType.Shell.name) {
161144
params["command"] = escapeProcessor.processEscapeSequences(rest.trim())
162-
} else if (toolName == "write-file") {
163-
// For write-file, if only one parameter is provided, it's the path
164-
// The content should be provided in a separate parameter or in the LLM context
145+
} else if (toolName == ToolType.WriteFile.name) {
165146
val firstLine = rest.lines().firstOrNull()?.trim()
166147
if (firstLine != null && firstLine.isNotEmpty()) {
167148
params["path"] = escapeProcessor.processEscapeSequences(firstLine)

mpp-core/src/commonMain/kotlin/cc/unitmesh/agent/policy/DefaultPolicyEngine.kt

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ package cc.unitmesh.agent.policy
22

33
import cc.unitmesh.agent.orchestrator.ToolExecutionContext
44
import cc.unitmesh.agent.state.ToolCall
5+
import cc.unitmesh.agent.tool.ToolType
56

67
/**
78
* Default implementation of PolicyEngine
@@ -131,7 +132,7 @@ class DefaultPolicyEngine : PolicyEngine {
131132
addRule(PolicyRule(
132133
name = "allow_read_file",
133134
description = "Allow reading files",
134-
toolPattern = "read-file",
135+
toolPattern = ToolType.ReadFile.name,
135136
decision = PolicyDecision.ALLOW,
136137
riskLevel = RiskLevel.LOW,
137138
priority = 10
@@ -140,7 +141,7 @@ class DefaultPolicyEngine : PolicyEngine {
140141
addRule(PolicyRule(
141142
name = "allow_write_file",
142143
description = "Allow writing files",
143-
toolPattern = "write-file",
144+
toolPattern = ToolType.WriteFile.name,
144145
decision = PolicyDecision.ALLOW,
145146
riskLevel = RiskLevel.MEDIUM,
146147
priority = 10
@@ -149,7 +150,7 @@ class DefaultPolicyEngine : PolicyEngine {
149150
addRule(PolicyRule(
150151
name = "allow_glob",
151152
description = "Allow file globbing",
152-
toolPattern = "glob",
153+
toolPattern = ToolType.Glob.name,
153154
decision = PolicyDecision.ALLOW,
154155
riskLevel = RiskLevel.LOW,
155156
priority = 10
@@ -158,7 +159,7 @@ class DefaultPolicyEngine : PolicyEngine {
158159
addRule(PolicyRule(
159160
name = "allow_grep",
160161
description = "Allow text searching",
161-
toolPattern = "grep",
162+
toolPattern = ToolType.Grep.name,
162163
decision = PolicyDecision.ALLOW,
163164
riskLevel = RiskLevel.LOW,
164165
priority = 10
@@ -168,7 +169,7 @@ class DefaultPolicyEngine : PolicyEngine {
168169
addRule(PolicyRule(
169170
name = "allow_shell",
170171
description = "Allow shell commands",
171-
toolPattern = "shell",
172+
toolPattern = ToolType.Shell.name,
172173
decision = PolicyDecision.ALLOW,
173174
riskLevel = RiskLevel.HIGH,
174175
priority = 5

mpp-core/src/commonMain/kotlin/cc/unitmesh/agent/recovery/ErrorRecoveryManager.kt

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
package cc.unitmesh.agent.recovery
22

33
import cc.unitmesh.agent.subagent.ErrorRecoveryAgent
4+
import cc.unitmesh.agent.tool.ToolType
45
import cc.unitmesh.llm.KoogLLMService
56

67
/**
@@ -60,12 +61,12 @@ class ErrorRecoveryManager(private val projectPath: String, private val llmServi
6061
*/
6162
private fun shouldAttemptRecovery(toolName: String, errorMessage: String): Boolean {
6263
// 对于 shell 命令错误,总是尝试恢复
63-
if (toolName == "shell") {
64+
if (toolName == ToolType.Shell.name) {
6465
return true
6566
}
6667

6768
// 对于文件操作错误,如果是权限或路径问题,尝试恢复
68-
if (toolName in listOf("write-file", "read-file")) {
69+
if (toolName in listOf(ToolType.ReadFile.name, ToolType.WriteFile.name)) {
6970
val recoverableErrors = listOf(
7071
"permission denied",
7172
"no such file or directory",

mpp-core/src/commonMain/kotlin/cc/unitmesh/agent/subagent/CodebaseInvestigatorAgent.kt

Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@ import cc.unitmesh.agent.model.PromptConfig
66
import cc.unitmesh.agent.model.RunConfig
77
import cc.unitmesh.agent.model.ToolConfig
88
import cc.unitmesh.agent.tool.ToolResult
9-
import cc.unitmesh.agent.tool.ToolNames
109
import cc.unitmesh.agent.tool.ToolType
1110
import cc.unitmesh.llm.KoogLLMService
1211
import cc.unitmesh.llm.ModelConfig
@@ -68,8 +67,6 @@ class CodebaseInvestigatorAgent(
6867
)
6968
)
7069
) {
71-
72-
// Simple in-memory cache for analysis results
7370
private var analysisCache: Map<String, String> = emptyMap()
7471

7572
override fun validateInput(input: Map<String, Any>): InvestigationContext {
@@ -141,17 +138,13 @@ class CodebaseInvestigatorAgent(
141138
}
142139
}
143140

144-
/**
145-
* Process the investigation query using simple text analysis
146-
*/
147141
private suspend fun processInvestigationQuery(
148142
input: InvestigationContext,
149143
onProgress: (String) -> Unit
150144
): InvestigationResult {
151145
val findings = mutableListOf<String>()
152146
val recommendations = mutableListOf<String>()
153147

154-
// Analyze query intent and extract relevant information
155148
val queryAnalysis = analyzeQuery(input.query)
156149

157150
when (input.scope) {

mpp-core/src/commonMain/kotlin/cc/unitmesh/agent/subagent/ErrorRecoveryAgent.kt

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -280,7 +280,6 @@ $context
280280
*/
281281
private fun parseRecoveryResponse(response: String): RecoveryResult {
282282
return try {
283-
// Try to extract JSON from response
284283
val jsonMatch = Regex("```json\\s*([\\s\\S]*?)\\s*```").find(response)?.groupValues?.get(1)
285284
?: Regex("\\{[\\s\\S]*?\\}").find(response)?.value
286285

0 commit comments

Comments
 (0)