Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 13 additions & 13 deletions apps/backend/src/agents/tools/execute-python.ts
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
import { executePython as schemas } from '@nao/shared/tools';
import { tool } from 'ai';
import fs from 'fs';
import path from 'path';

import { getProjectFolder, isWithinProjectFolder, toVirtualPath } from '../../utils/tools';
import { createTool } from '../../types/tools';
import { isWithinProjectFolder, toVirtualPath } from '../../utils/tools';

// @pydantic/monty uses native bindings that aren't available on all platforms
// (e.g. no Linux ARM64 binary). Load lazily so the server can still start.
Expand All @@ -21,14 +21,17 @@ const RESOURCE_LIMITS = {
maxRecursionDepth: 500,
};

export async function executePython({ code, inputs }: schemas.Input): Promise<schemas.Output> {
async function executePython(
{ code, inputs }: schemas.Input,
{ projectFolder }: { projectFolder: string },
): Promise<schemas.Output> {
if (!montyModule) {
throw new Error('Python execution is not available on this platform');
}

const { Monty, MontyRuntimeError, MontySnapshot, MontySyntaxError, MontyTypingError } = montyModule;
const inputNames = inputs ? Object.keys(inputs) : [];
const virtualFS = createVirtualFS();
const virtualFS = createVirtualFS(projectFolder);

let monty: InstanceType<typeof Monty>;
try {
Expand Down Expand Up @@ -123,15 +126,10 @@ function findAllFiles(dir: string, projectFolder: string): schemas.VirtualFile[]
return files;
}

function loadProjectFiles(): schemas.VirtualFile[] {
const projectFolder = getProjectFolder();
return findAllFiles(projectFolder, projectFolder);
}

function createVirtualFS(): Map<string, string> {
function createVirtualFS(projectFolder: string): Map<string, string> {
const vfs = new Map<string, string>();

const projectFiles = loadProjectFiles();
const projectFiles = findAllFiles(projectFolder, projectFolder);
for (const file of projectFiles) {
vfs.set(file.path, file.content);
}
Expand All @@ -145,10 +143,12 @@ const EXTERNAL_FUNCTION_NAMES = schemas.EXTERNAL_FUNCTIONS.map((f) => f.name);
export const isPythonAvailable = montyModule !== null;

export default montyModule
? tool({
? createTool({
description: schemas.description,
inputSchema: schemas.inputSchema,
outputSchema: schemas.outputSchema,
execute: executePython,
execute: async (input, context) => {
return executePython(input, { projectFolder: context.projectFolder });
},
})
: null;
30 changes: 15 additions & 15 deletions apps/backend/src/agents/tools/execute-sql.ts
Original file line number Diff line number Diff line change
@@ -1,24 +1,15 @@
import type { executeSql } from '@nao/shared/tools';
import { executeSql as schemas } from '@nao/shared/tools';
import { tool } from 'ai';

import { ExecuteSqlOutput, renderToModelOutput } from '../../components/tool-outputs';
import { env } from '../../env';
import { getProjectFolder } from '../../utils/tools';
import { createTool, type ToolContext } from '../../types/tools';

export default tool<executeSql.Input, executeSql.Output>({
description:
'Execute a SQL query against the connected database and return the results. If multiple databases are configured, specify the database_id.',
inputSchema: schemas.InputSchema,
outputSchema: schemas.OutputSchema,

execute: executeQuery,

toModelOutput: ({ output }) => renderToModelOutput(ExecuteSqlOutput({ output }), output),
});

export async function executeQuery({ sql_query, database_id }: executeSql.Input): Promise<executeSql.Output> {
const naoProjectFolder = getProjectFolder();
export async function executeQuery(
{ sql_query, database_id }: executeSql.Input,
context: ToolContext,
): Promise<executeSql.Output> {
const naoProjectFolder = context.projectFolder;

const response = await fetch(`http://localhost:${env.FASTAPI_PORT}/execute_sql`, {
method: 'POST',
Expand All @@ -44,3 +35,12 @@ export async function executeQuery({ sql_query, database_id }: executeSql.Input)
id: `query_${crypto.randomUUID().slice(0, 8)}`,
};
}

export default createTool({
description:
'Execute a SQL query against the connected database and return the results. If multiple databases are configured, specify the database_id.',
inputSchema: schemas.InputSchema,
outputSchema: schemas.OutputSchema,
execute: executeQuery,
toModelOutput: ({ output }) => renderToModelOutput(ExecuteSqlOutput({ output }), output),
});
20 changes: 8 additions & 12 deletions apps/backend/src/agents/tools/grep.ts
Original file line number Diff line number Diff line change
@@ -1,17 +1,11 @@
import { grep } from '@nao/shared/tools';
import { tool } from 'ai';
import { spawn } from 'child_process';
import fs from 'fs';
import path from 'path';

import { GrepOutput, renderToModelOutput } from '../../components/tool-outputs';
import {
getProjectFolder,
isWithinProjectFolder,
loadNaoignorePatterns,
toRealPath,
toVirtualPath,
} from '../../utils/tools';
import { createTool } from '../../types/tools';
import { isWithinProjectFolder, loadNaoignorePatterns, toRealPath, toVirtualPath } from '../../utils/tools';

/**
* Gets the path to the ripgrep binary.
Expand Down Expand Up @@ -53,13 +47,15 @@ interface RipgrepMatch {
context_after?: string[];
}

export default tool<grep.Input, grep.Output>({
export default createTool({
description: 'Search for text patterns in files using ripgrep. Supports regex patterns and respects .gitignore.',
inputSchema: grep.InputSchema,
outputSchema: grep.OutputSchema,

execute: async ({ pattern, path: searchPath, glob, case_insensitive, context_lines, max_results = 100 }) => {
const projectFolder = getProjectFolder();
execute: async (
{ pattern, path: searchPath, glob, case_insensitive, context_lines, max_results = 100 },
context,
) => {
const projectFolder = context.projectFolder;
const rgPath = getRipgrepPath();

// Determine the search path
Expand Down
13 changes: 6 additions & 7 deletions apps/backend/src/agents/tools/list.ts
Original file line number Diff line number Diff line change
@@ -1,18 +1,17 @@
import { list } from '@nao/shared/tools';
import { tool } from 'ai';
import fs from 'fs/promises';
import path from 'path';

import { ListOutput, renderToModelOutput } from '../../components/tool-outputs';
import { getProjectFolder, shouldExcludeEntry, toRealPath, toVirtualPath } from '../../utils/tools';
import { createTool } from '../../types/tools';
import { shouldExcludeEntry, toRealPath, toVirtualPath } from '../../utils/tools';

export default tool<list.Input, list.Output>({
export default createTool({
description: 'List files and directories at the specified path.',
inputSchema: list.InputSchema,
outputSchema: list.OutputSchema,

execute: async ({ path: filePath }) => {
const projectFolder = getProjectFolder();
execute: async ({ path: filePath }, context) => {
const projectFolder = context.projectFolder;
const realPath = toRealPath(filePath, projectFolder);

// Get the relative path of the parent directory for naoignore matching
Expand Down Expand Up @@ -58,7 +57,7 @@ export default tool<list.Input, list.Output>({
}),
);

return { _version: '1', entries };
return { _version: '1' as const, entries };
},

toModelOutput: ({ output }) => renderToModelOutput(ListOutput({ output }), output),
Expand Down
13 changes: 6 additions & 7 deletions apps/backend/src/agents/tools/read.ts
Original file line number Diff line number Diff line change
@@ -1,24 +1,23 @@
import { readFile } from '@nao/shared/tools';
import { tool } from 'ai';
import fs from 'fs/promises';

import { ReadOutput, renderToModelOutput } from '../../components/tool-outputs';
import { getProjectFolder, toRealPath } from '../../utils/tools';
import { createTool } from '../../types/tools';
import { toRealPath } from '../../utils/tools';

export default tool<readFile.Input, readFile.Output>({
export default createTool({
description: 'Read the contents of a file at the specified path.',
inputSchema: readFile.InputSchema,
outputSchema: readFile.OutputSchema,

execute: async ({ file_path }) => {
const projectFolder = getProjectFolder();
execute: async ({ file_path }, context) => {
const projectFolder = context.projectFolder;
const realPath = toRealPath(file_path, projectFolder);

const content = await fs.readFile(realPath, 'utf-8');
const numberOfTotalLines = content.split('\n').length;

return {
_version: '1',
_version: '1' as const,
content,
numberOfTotalLines,
};
Expand Down
13 changes: 6 additions & 7 deletions apps/backend/src/agents/tools/search.ts
Original file line number Diff line number Diff line change
@@ -1,19 +1,18 @@
import { searchFiles } from '@nao/shared/tools';
import { tool } from 'ai';
import fs from 'fs/promises';
import { glob } from 'glob';
import path from 'path';

import { renderToModelOutput, SearchOutput } from '../../components/tool-outputs';
import { getProjectFolder, isWithinProjectFolder, loadNaoignorePatterns, toVirtualPath } from '../../utils/tools';
import { createTool } from '../../types/tools';
import { isWithinProjectFolder, loadNaoignorePatterns, toVirtualPath } from '../../utils/tools';

export default tool<searchFiles.Input, searchFiles.Output>({
export default createTool({
description: 'Search for files matching a glob pattern within the project.',
inputSchema: searchFiles.InputSchema,
outputSchema: searchFiles.OutputSchema,

execute: async ({ pattern }) => {
const projectFolder = getProjectFolder();
execute: async ({ pattern }, context) => {
const projectFolder = context.projectFolder;

// Sanitize pattern to prevent escaping project folder
if (path.isAbsolute(pattern)) {
Expand Down Expand Up @@ -55,7 +54,7 @@ export default tool<searchFiles.Input, searchFiles.Output>({
}),
);

return { _version: '1', files };
return { _version: '1' as const, files };
},

toModelOutput: ({ output }) => renderToModelOutput(SearchOutput({ output }), output),
Expand Down
7 changes: 6 additions & 1 deletion apps/backend/src/routes/test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import { authMiddleware } from '../middleware/auth';
import { ModelSelection } from '../services/agent.service';
import { TestAgentService, testAgentService } from '../services/test-agent.service';
import { llmProviderSchema } from '../types/llm';
import { retrieveProjectById } from '../utils/chat';

const modelSelectionSchema = z.object({
provider: llmProviderSchema,
Expand Down Expand Up @@ -43,10 +44,14 @@ export const testRoutes = async (app: App) => {
try {
const modelSelection = model as ModelSelection | undefined;
const result = await testAgentService.runTest(projectId, prompt, modelSelection);
const project = await retrieveProjectById(projectId);

let verification;
if (sql) {
const { data: expectedData, columns: expectedColumns } = await executeQuery({ sql_query: sql });
const { data: expectedData, columns: expectedColumns } = await executeQuery(
{ sql_query: sql },
{ projectFolder: project.path! },
);
const { data } = await testAgentService.runVerification(
projectId,
result,
Expand Down
9 changes: 8 additions & 1 deletion apps/backend/src/services/agent.service.ts
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ import * as projectQueries from '../queries/project.queries';
import * as llmConfigQueries from '../queries/project-llm-config.queries';
import { AgentSettings } from '../types/agent-settings';
import { TokenCost, TokenUsage, UIChat, UIMessage } from '../types/chat';
import { convertToCost, convertToTokenUsage } from '../utils/chat';
import { convertToCost, convertToTokenUsage, retrieveProjectById } from '../utils/chat';
import { getDefaultModelId, getEnvApiKey, getEnvModelSelections, ModelSelection } from '../utils/llm';

export type { ModelSelection };
Expand Down Expand Up @@ -180,11 +180,18 @@ class AgentManager {
});
}

// Fetch project path and run agent within project context
const project = await retrieveProjectById(this.chat.projectId);

const messages = await this._buildModelMessages(uiMessages);

result = await this._agent.stream({
messages,
abortSignal: this._abortController.signal,
// @ts-expect-error - experimental_context is not yet in the types
experimental_context: {
projectFolder: project.path,
},
});

writer.merge(result.toUIMessageStream({}));
Expand Down
32 changes: 32 additions & 0 deletions apps/backend/src/types/tools.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
import type { ToolResultOutput } from '@ai-sdk/provider-utils';
import { tool } from 'ai';
import type { z } from 'zod/v3';

type ZodSchema = z.ZodTypeAny;

export interface ToolContext {
projectFolder: string;
}

export interface ToolDefinition<TInput extends ZodSchema, TOutput extends ZodSchema> {
description: string;
inputSchema: TInput;
outputSchema: TOutput;
execute: (input: z.infer<TInput>, context: ToolContext) => Promise<z.infer<TOutput>>;
toModelOutput?: (params: { output: z.infer<TOutput> }) => ToolResultOutput;
}

export function createTool<TInput extends ZodSchema, TOutput extends ZodSchema>(
definition: ToolDefinition<TInput, TOutput>,
) {
return tool({
description: definition.description,
inputSchema: definition.inputSchema,
outputSchema: definition.outputSchema,
execute: async (input, { experimental_context }) => {
const context = experimental_context as ToolContext;
return definition.execute(input, context);
},
...(definition.toModelOutput && { toModelOutput: definition.toModelOutput }),
});
}
13 changes: 13 additions & 0 deletions apps/backend/src/utils/chat.ts
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
import { LanguageModelUsage } from 'ai';

import { LLM_PROVIDERS } from '../agents/providers';
import * as projectQueries from '../queries/project.queries';
import { DBProject } from '../queries/project-slack-config.queries';
import { TokenCost, TokenUsage } from '../types/chat';
import { LlmProvider } from '../types/llm';

Expand Down Expand Up @@ -51,3 +53,14 @@ export const extractLastTextFromMessage = (message: { parts: { type: string; tex
}
return '';
};

export const retrieveProjectById = async (projectId: string): Promise<DBProject> => {
const project = await projectQueries.getProjectById(projectId);
if (!project) {
throw new Error(`Project not found: ${projectId}`);
}
if (!project.path) {
throw new Error(`Project path not configured: ${projectId}`);
}
return project;
};
14 changes: 0 additions & 14 deletions apps/backend/src/utils/tools.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,6 @@ import fs from 'fs';
import { minimatch } from 'minimatch';
import path from 'path';

import { env } from '../env';

const MCP_TOOL_SEPARATOR = '__';

/**
Expand Down Expand Up @@ -146,18 +144,6 @@ export const shouldExcludeEntry = (entryName: string, parentPath: string, projec
return isIgnoredByNaoignore(relativePath, projectFolder);
};

/**
* Gets the resolved project folder path from the NAO_DEFAULT_PROJECT_PATH environment variable.
* @throws Error if NAO_DEFAULT_PROJECT_PATH is not set
*/
export const getProjectFolder = (): string => {
const projectFolder = env.NAO_DEFAULT_PROJECT_PATH;
if (!projectFolder) {
throw new Error('NAO_DEFAULT_PROJECT_PATH environment variable is not set');
}
return path.resolve(projectFolder);
};

/**
* Checks if a given path is within the project folder, not in an excluded directory,
* and not ignored by .naoignore.
Expand Down
Loading