diff --git a/apps/backend/src/agents/tools/execute-python.ts b/apps/backend/src/agents/tools/execute-python.ts index ab1beb1e..3cefff63 100644 --- a/apps/backend/src/agents/tools/execute-python.ts +++ b/apps/backend/src/agents/tools/execute-python.ts @@ -1,9 +1,9 @@ import { executePython as schemas } from '@nao/shared/tools'; -import { tool } from 'ai'; import fs from 'fs'; import path from 'path'; -import { getProjectFolder, isWithinProjectFolder, toVirtualPath } from '../../utils/tools'; +import { createTool } from '../../types/tools'; +import { isWithinProjectFolder, toVirtualPath } from '../../utils/tools'; // @pydantic/monty uses native bindings that aren't available on all platforms // (e.g. no Linux ARM64 binary). Load lazily so the server can still start. @@ -21,14 +21,17 @@ const RESOURCE_LIMITS = { maxRecursionDepth: 500, }; -export async function executePython({ code, inputs }: schemas.Input): Promise { +async function executePython( + { code, inputs }: schemas.Input, + { projectFolder }: { projectFolder: string }, +): Promise { if (!montyModule) { throw new Error('Python execution is not available on this platform'); } const { Monty, MontyRuntimeError, MontySnapshot, MontySyntaxError, MontyTypingError } = montyModule; const inputNames = inputs ? Object.keys(inputs) : []; - const virtualFS = createVirtualFS(); + const virtualFS = createVirtualFS(projectFolder); let monty: InstanceType; try { @@ -123,15 +126,10 @@ function findAllFiles(dir: string, projectFolder: string): schemas.VirtualFile[] return files; } -function loadProjectFiles(): schemas.VirtualFile[] { - const projectFolder = getProjectFolder(); - return findAllFiles(projectFolder, projectFolder); -} - -function createVirtualFS(): Map { +function createVirtualFS(projectFolder: string): Map { const vfs = new Map(); - const projectFiles = loadProjectFiles(); + const projectFiles = findAllFiles(projectFolder, projectFolder); for (const file of projectFiles) { vfs.set(file.path, file.content); } @@ -145,10 +143,12 @@ const EXTERNAL_FUNCTION_NAMES = schemas.EXTERNAL_FUNCTIONS.map((f) => f.name); export const isPythonAvailable = montyModule !== null; export default montyModule - ? tool({ + ? createTool({ description: schemas.description, inputSchema: schemas.inputSchema, outputSchema: schemas.outputSchema, - execute: executePython, + execute: async (input, context) => { + return executePython(input, { projectFolder: context.projectFolder }); + }, }) : null; diff --git a/apps/backend/src/agents/tools/execute-sql.ts b/apps/backend/src/agents/tools/execute-sql.ts index be918ce8..ed6d120b 100644 --- a/apps/backend/src/agents/tools/execute-sql.ts +++ b/apps/backend/src/agents/tools/execute-sql.ts @@ -1,24 +1,15 @@ import type { executeSql } from '@nao/shared/tools'; import { executeSql as schemas } from '@nao/shared/tools'; -import { tool } from 'ai'; import { ExecuteSqlOutput, renderToModelOutput } from '../../components/tool-outputs'; import { env } from '../../env'; -import { getProjectFolder } from '../../utils/tools'; +import { createTool, type ToolContext } from '../../types/tools'; -export default tool({ - description: - 'Execute a SQL query against the connected database and return the results. If multiple databases are configured, specify the database_id.', - inputSchema: schemas.InputSchema, - outputSchema: schemas.OutputSchema, - - execute: executeQuery, - - toModelOutput: ({ output }) => renderToModelOutput(ExecuteSqlOutput({ output }), output), -}); - -export async function executeQuery({ sql_query, database_id }: executeSql.Input): Promise { - const naoProjectFolder = getProjectFolder(); +export async function executeQuery( + { sql_query, database_id }: executeSql.Input, + context: ToolContext, +): Promise { + const naoProjectFolder = context.projectFolder; const response = await fetch(`http://localhost:${env.FASTAPI_PORT}/execute_sql`, { method: 'POST', @@ -44,3 +35,12 @@ export async function executeQuery({ sql_query, database_id }: executeSql.Input) id: `query_${crypto.randomUUID().slice(0, 8)}`, }; } + +export default createTool({ + description: + 'Execute a SQL query against the connected database and return the results. If multiple databases are configured, specify the database_id.', + inputSchema: schemas.InputSchema, + outputSchema: schemas.OutputSchema, + execute: executeQuery, + toModelOutput: ({ output }) => renderToModelOutput(ExecuteSqlOutput({ output }), output), +}); diff --git a/apps/backend/src/agents/tools/grep.ts b/apps/backend/src/agents/tools/grep.ts index 8c97abe8..937f1ad1 100644 --- a/apps/backend/src/agents/tools/grep.ts +++ b/apps/backend/src/agents/tools/grep.ts @@ -1,17 +1,11 @@ import { grep } from '@nao/shared/tools'; -import { tool } from 'ai'; import { spawn } from 'child_process'; import fs from 'fs'; import path from 'path'; import { GrepOutput, renderToModelOutput } from '../../components/tool-outputs'; -import { - getProjectFolder, - isWithinProjectFolder, - loadNaoignorePatterns, - toRealPath, - toVirtualPath, -} from '../../utils/tools'; +import { createTool } from '../../types/tools'; +import { isWithinProjectFolder, loadNaoignorePatterns, toRealPath, toVirtualPath } from '../../utils/tools'; /** * Gets the path to the ripgrep binary. @@ -53,13 +47,15 @@ interface RipgrepMatch { context_after?: string[]; } -export default tool({ +export default createTool({ description: 'Search for text patterns in files using ripgrep. Supports regex patterns and respects .gitignore.', inputSchema: grep.InputSchema, outputSchema: grep.OutputSchema, - - execute: async ({ pattern, path: searchPath, glob, case_insensitive, context_lines, max_results = 100 }) => { - const projectFolder = getProjectFolder(); + execute: async ( + { pattern, path: searchPath, glob, case_insensitive, context_lines, max_results = 100 }, + context, + ) => { + const projectFolder = context.projectFolder; const rgPath = getRipgrepPath(); // Determine the search path diff --git a/apps/backend/src/agents/tools/list.ts b/apps/backend/src/agents/tools/list.ts index f1bfccd0..216b383a 100644 --- a/apps/backend/src/agents/tools/list.ts +++ b/apps/backend/src/agents/tools/list.ts @@ -1,18 +1,17 @@ import { list } from '@nao/shared/tools'; -import { tool } from 'ai'; import fs from 'fs/promises'; import path from 'path'; import { ListOutput, renderToModelOutput } from '../../components/tool-outputs'; -import { getProjectFolder, shouldExcludeEntry, toRealPath, toVirtualPath } from '../../utils/tools'; +import { createTool } from '../../types/tools'; +import { shouldExcludeEntry, toRealPath, toVirtualPath } from '../../utils/tools'; -export default tool({ +export default createTool({ description: 'List files and directories at the specified path.', inputSchema: list.InputSchema, outputSchema: list.OutputSchema, - - execute: async ({ path: filePath }) => { - const projectFolder = getProjectFolder(); + execute: async ({ path: filePath }, context) => { + const projectFolder = context.projectFolder; const realPath = toRealPath(filePath, projectFolder); // Get the relative path of the parent directory for naoignore matching @@ -58,7 +57,7 @@ export default tool({ }), ); - return { _version: '1', entries }; + return { _version: '1' as const, entries }; }, toModelOutput: ({ output }) => renderToModelOutput(ListOutput({ output }), output), diff --git a/apps/backend/src/agents/tools/read.ts b/apps/backend/src/agents/tools/read.ts index 1b8a05f6..d78b7e3d 100644 --- a/apps/backend/src/agents/tools/read.ts +++ b/apps/backend/src/agents/tools/read.ts @@ -1,24 +1,23 @@ import { readFile } from '@nao/shared/tools'; -import { tool } from 'ai'; import fs from 'fs/promises'; import { ReadOutput, renderToModelOutput } from '../../components/tool-outputs'; -import { getProjectFolder, toRealPath } from '../../utils/tools'; +import { createTool } from '../../types/tools'; +import { toRealPath } from '../../utils/tools'; -export default tool({ +export default createTool({ description: 'Read the contents of a file at the specified path.', inputSchema: readFile.InputSchema, outputSchema: readFile.OutputSchema, - - execute: async ({ file_path }) => { - const projectFolder = getProjectFolder(); + execute: async ({ file_path }, context) => { + const projectFolder = context.projectFolder; const realPath = toRealPath(file_path, projectFolder); const content = await fs.readFile(realPath, 'utf-8'); const numberOfTotalLines = content.split('\n').length; return { - _version: '1', + _version: '1' as const, content, numberOfTotalLines, }; diff --git a/apps/backend/src/agents/tools/search.ts b/apps/backend/src/agents/tools/search.ts index c9f2c3fd..5dc3242d 100644 --- a/apps/backend/src/agents/tools/search.ts +++ b/apps/backend/src/agents/tools/search.ts @@ -1,19 +1,18 @@ import { searchFiles } from '@nao/shared/tools'; -import { tool } from 'ai'; import fs from 'fs/promises'; import { glob } from 'glob'; import path from 'path'; import { renderToModelOutput, SearchOutput } from '../../components/tool-outputs'; -import { getProjectFolder, isWithinProjectFolder, loadNaoignorePatterns, toVirtualPath } from '../../utils/tools'; +import { createTool } from '../../types/tools'; +import { isWithinProjectFolder, loadNaoignorePatterns, toVirtualPath } from '../../utils/tools'; -export default tool({ +export default createTool({ description: 'Search for files matching a glob pattern within the project.', inputSchema: searchFiles.InputSchema, outputSchema: searchFiles.OutputSchema, - - execute: async ({ pattern }) => { - const projectFolder = getProjectFolder(); + execute: async ({ pattern }, context) => { + const projectFolder = context.projectFolder; // Sanitize pattern to prevent escaping project folder if (path.isAbsolute(pattern)) { @@ -55,7 +54,7 @@ export default tool({ }), ); - return { _version: '1', files }; + return { _version: '1' as const, files }; }, toModelOutput: ({ output }) => renderToModelOutput(SearchOutput({ output }), output), diff --git a/apps/backend/src/routes/test.ts b/apps/backend/src/routes/test.ts index 4f08f291..f7a4f1f4 100644 --- a/apps/backend/src/routes/test.ts +++ b/apps/backend/src/routes/test.ts @@ -6,6 +6,7 @@ import { authMiddleware } from '../middleware/auth'; import { ModelSelection } from '../services/agent.service'; import { TestAgentService, testAgentService } from '../services/test-agent.service'; import { llmProviderSchema } from '../types/llm'; +import { retrieveProjectById } from '../utils/chat'; const modelSelectionSchema = z.object({ provider: llmProviderSchema, @@ -43,10 +44,14 @@ export const testRoutes = async (app: App) => { try { const modelSelection = model as ModelSelection | undefined; const result = await testAgentService.runTest(projectId, prompt, modelSelection); + const project = await retrieveProjectById(projectId); let verification; if (sql) { - const { data: expectedData, columns: expectedColumns } = await executeQuery({ sql_query: sql }); + const { data: expectedData, columns: expectedColumns } = await executeQuery( + { sql_query: sql }, + { projectFolder: project.path! }, + ); const { data } = await testAgentService.runVerification( projectId, result, diff --git a/apps/backend/src/services/agent.service.ts b/apps/backend/src/services/agent.service.ts index 838bb542..a724eba8 100644 --- a/apps/backend/src/services/agent.service.ts +++ b/apps/backend/src/services/agent.service.ts @@ -18,7 +18,7 @@ import * as projectQueries from '../queries/project.queries'; import * as llmConfigQueries from '../queries/project-llm-config.queries'; import { AgentSettings } from '../types/agent-settings'; import { TokenCost, TokenUsage, UIChat, UIMessage } from '../types/chat'; -import { convertToCost, convertToTokenUsage } from '../utils/chat'; +import { convertToCost, convertToTokenUsage, retrieveProjectById } from '../utils/chat'; import { getDefaultModelId, getEnvApiKey, getEnvModelSelections, ModelSelection } from '../utils/llm'; export type { ModelSelection }; @@ -180,11 +180,18 @@ class AgentManager { }); } + // Fetch project path and run agent within project context + const project = await retrieveProjectById(this.chat.projectId); + const messages = await this._buildModelMessages(uiMessages); result = await this._agent.stream({ messages, abortSignal: this._abortController.signal, + // @ts-expect-error - experimental_context is not yet in the types + experimental_context: { + projectFolder: project.path, + }, }); writer.merge(result.toUIMessageStream({})); diff --git a/apps/backend/src/types/tools.ts b/apps/backend/src/types/tools.ts new file mode 100644 index 00000000..bdec600e --- /dev/null +++ b/apps/backend/src/types/tools.ts @@ -0,0 +1,32 @@ +import type { ToolResultOutput } from '@ai-sdk/provider-utils'; +import { tool } from 'ai'; +import type { z } from 'zod/v3'; + +type ZodSchema = z.ZodTypeAny; + +export interface ToolContext { + projectFolder: string; +} + +export interface ToolDefinition { + description: string; + inputSchema: TInput; + outputSchema: TOutput; + execute: (input: z.infer, context: ToolContext) => Promise>; + toModelOutput?: (params: { output: z.infer }) => ToolResultOutput; +} + +export function createTool( + definition: ToolDefinition, +) { + return tool({ + description: definition.description, + inputSchema: definition.inputSchema, + outputSchema: definition.outputSchema, + execute: async (input, { experimental_context }) => { + const context = experimental_context as ToolContext; + return definition.execute(input, context); + }, + ...(definition.toModelOutput && { toModelOutput: definition.toModelOutput }), + }); +} diff --git a/apps/backend/src/utils/chat.ts b/apps/backend/src/utils/chat.ts index d00cddf1..01973fd8 100644 --- a/apps/backend/src/utils/chat.ts +++ b/apps/backend/src/utils/chat.ts @@ -1,6 +1,8 @@ import { LanguageModelUsage } from 'ai'; import { LLM_PROVIDERS } from '../agents/providers'; +import * as projectQueries from '../queries/project.queries'; +import { DBProject } from '../queries/project-slack-config.queries'; import { TokenCost, TokenUsage } from '../types/chat'; import { LlmProvider } from '../types/llm'; @@ -51,3 +53,14 @@ export const extractLastTextFromMessage = (message: { parts: { type: string; tex } return ''; }; + +export const retrieveProjectById = async (projectId: string): Promise => { + const project = await projectQueries.getProjectById(projectId); + if (!project) { + throw new Error(`Project not found: ${projectId}`); + } + if (!project.path) { + throw new Error(`Project path not configured: ${projectId}`); + } + return project; +}; diff --git a/apps/backend/src/utils/tools.ts b/apps/backend/src/utils/tools.ts index e8ee0c88..980e1635 100644 --- a/apps/backend/src/utils/tools.ts +++ b/apps/backend/src/utils/tools.ts @@ -2,8 +2,6 @@ import fs from 'fs'; import { minimatch } from 'minimatch'; import path from 'path'; -import { env } from '../env'; - const MCP_TOOL_SEPARATOR = '__'; /** @@ -146,18 +144,6 @@ export const shouldExcludeEntry = (entryName: string, parentPath: string, projec return isIgnoredByNaoignore(relativePath, projectFolder); }; -/** - * Gets the resolved project folder path from the NAO_DEFAULT_PROJECT_PATH environment variable. - * @throws Error if NAO_DEFAULT_PROJECT_PATH is not set - */ -export const getProjectFolder = (): string => { - const projectFolder = env.NAO_DEFAULT_PROJECT_PATH; - if (!projectFolder) { - throw new Error('NAO_DEFAULT_PROJECT_PATH environment variable is not set'); - } - return path.resolve(projectFolder); -}; - /** * Checks if a given path is within the project folder, not in an excluded directory, * and not ignored by .naoignore.