diff --git a/src/config.ts b/src/config.ts index 8e0093ac..7c52d662 100644 --- a/src/config.ts +++ b/src/config.ts @@ -64,6 +64,7 @@ export class Config { serverLogDirectory: string; boundedContext: BoundedContext; tableauServerVersionCheckIntervalInHours: number; + resultSizeLimitKb: number | null; oauth: { enabled: boolean; issuer: string; @@ -121,6 +122,7 @@ export class Config { INCLUDE_DATASOURCE_IDS: includeDatasourceIds, INCLUDE_WORKBOOK_IDS: includeWorkbookIds, TABLEAU_SERVER_VERSION_CHECK_INTERVAL_IN_HOURS: tableauServerVersionCheckIntervalInHours, + RESULT_SIZE_LIMIT_KB: resultSizeLimitKb, DANGEROUSLY_DISABLE_OAUTH: disableOauth, OAUTH_ISSUER: oauthIssuer, OAUTH_LOCK_SITE: oauthLockSite, @@ -190,6 +192,13 @@ export class Config { }, ); + this.resultSizeLimitKb = resultSizeLimitKb + ? parseNumber(resultSizeLimitKb, { + defaultValue: 1024, + minValue: 0, + }) + : null; + const disableOauthOverride = disableOauth === 'true'; this.oauth = { enabled: disableOauthOverride ? false : !!oauthIssuer, diff --git a/src/index.ts b/src/index.ts index be655454..be723b7b 100644 --- a/src/index.ts +++ b/src/index.ts @@ -9,6 +9,8 @@ import { Server, serverName, serverVersion } from './server.js'; import { startExpressServer } from './server/express.js'; import { getExceptionMessage } from './utils/getExceptionMessage.js'; +let serverUrl: string | undefined; + async function startServer(): Promise { dotenv.config(); const config = getConfig(); @@ -33,6 +35,7 @@ async function startServer(): Promise { } case 'http': { const { url } = await startExpressServer({ basePath: serverName, config, logLevel }); + serverUrl = url; if (!config.oauth.enabled) { console.warn( @@ -57,3 +60,7 @@ startServer().catch((error) => { writeToStderr(`Fatal error when starting the server: ${getExceptionMessage(error)}`); process.exit(1); }); + +export function getServerUrl(): string | undefined { + return serverUrl; +} diff --git a/src/scripts/createClaudeMcpBundleManifest.ts b/src/scripts/createClaudeMcpBundleManifest.ts index ef48d950..63d6e1ba 100644 --- a/src/scripts/createClaudeMcpBundleManifest.ts +++ b/src/scripts/createClaudeMcpBundleManifest.ts @@ -348,6 +348,14 @@ const envVars = { required: false, sensitive: false, }, + RESULT_SIZE_LIMIT_KB: { + includeInUserConfig: false, + type: 'number', + title: 'Result Size Limit (kb)', + description: 'The maximum size of the result in kilobytes.', + required: false, + sensitive: false, + }, DANGEROUSLY_DISABLE_OAUTH: { includeInUserConfig: false, type: 'boolean', diff --git a/src/server/express.ts b/src/server/express.ts index 7b52548c..7357a457 100644 --- a/src/server/express.ts +++ b/src/server/express.ts @@ -2,7 +2,7 @@ import { StreamableHTTPServerTransport } from '@modelcontextprotocol/sdk/server/ import { isInitializeRequest, LoggingLevel } from '@modelcontextprotocol/sdk/types.js'; import cors from 'cors'; import express, { Request, RequestHandler, Response } from 'express'; -import fs, { existsSync } from 'fs'; +import fs, { existsSync, unlinkSync } from 'fs'; import http from 'http'; import https from 'https'; @@ -10,7 +10,12 @@ import { Config } from '../config.js'; import { setLogLevel } from '../logging/log.js'; import { Server } from '../server.js'; import { createSession, getSession, Session } from '../sessions.js'; -import { handlePingRequest, validateProtocolVersion } from './middleware.js'; +import { getLargeResultFilePath } from './getLargeResult.js'; +import { + getRateLimitMiddleware, + handlePingRequest, + validateProtocolVersion, +} from './middleware.js'; import { getTableauAuthInfo } from './oauth/getTableauAuthInfo.js'; import { OAuthProvider } from './oauth/provider.js'; import { TableauAuthInfo } from './oauth/schemas.js'; @@ -52,7 +57,11 @@ export async function startExpressServer({ app.set('trust proxy', config.trustProxyConfig); } - const middleware: Array = [handlePingRequest]; + const middleware: Array = [ + handlePingRequest, + getRateLimitMiddleware({ windowMs: 60000, maxRequests: 30, responseFormat: 'mcp' }), + ]; + if (config.oauth.enabled) { const oauthProvider = new OAuthProvider(); oauthProvider.setupRoutes(app); @@ -73,6 +82,36 @@ export async function startExpressServer({ config.disableSessionManagement ? methodNotAllowed : handleSessionRequest, ); + app.get( + `${path}/results/:filename`, + getRateLimitMiddleware({ windowMs: 60000, maxRequests: 5, responseFormat: 'html' }), + (req, res) => { + const filename = req.params.filename; + + const result = getLargeResultFilePath(filename); + if (result.isErr()) { + res.status(result.error.status).send(result.error.message); + return; + } + + const { fullFilePath } = result.value; + res.download(fullFilePath, `${filename}.txt`, (err) => { + if (err) { + // Don't delete the file if there was an error sending it + console.error(`Error sending file ${fullFilePath}:`, err); + return; + } + + // File was successfully sent, it is now safe to delete + try { + unlinkSync(fullFilePath); + } catch (deleteErr) { + console.error(`Error deleting file ${fullFilePath}:`, deleteErr); + } + }); + }, + ); + const useSsl = !!(config.sslKey && config.sslCert); if (!useSsl) { return new Promise((resolve) => { diff --git a/src/server/getLargeResult.ts b/src/server/getLargeResult.ts new file mode 100644 index 00000000..51d6b04d --- /dev/null +++ b/src/server/getLargeResult.ts @@ -0,0 +1,22 @@ +import { existsSync } from 'fs'; +import { join } from 'path'; +import { Err, Ok, Result } from 'ts-results-es'; + +import { getDirname } from '../utils/getDirname'; + +const uuidV4Regex = /^[0-9a-f]{8}-[0-9a-f]{4}-4[0-9a-f]{3}-[89ab][0-9a-f]{3}-[0-9a-f]{12}$/i; + +export function getLargeResultFilePath( + fileResourceId: string, +): Result<{ fullFilePath: string }, { status: number; message: string }> { + if (!uuidV4Regex.test(fileResourceId)) { + return Err({ status: 400, message: 'Invalid file resource ID' }); + } + + const filePath = join(getDirname(), 'results', `${fileResourceId}.txt`); + if (!existsSync(filePath)) { + return Err({ status: 404, message: 'Result not found' }); + } + + return Ok({ fullFilePath: filePath }); +} diff --git a/src/server/middleware.ts b/src/server/middleware.ts index e8650133..6a2f4dcf 100644 --- a/src/server/middleware.ts +++ b/src/server/middleware.ts @@ -1,5 +1,5 @@ import { PingRequestSchema } from '@modelcontextprotocol/sdk/types.js'; -import { NextFunction, Request, Response } from 'express'; +import { NextFunction, Request, RequestHandler, Response } from 'express'; /** * Validate MCP protocol version @@ -44,3 +44,56 @@ export function handlePingRequest(req: Request, res: Response, next: NextFunctio } next(); } + +const requestCounts = new Map(); + +export function getRateLimitMiddleware({ + windowMs, + maxRequests, + responseFormat, +}: { + windowMs: number; + maxRequests: number; + responseFormat: 'mcp' | 'html'; +}): RequestHandler { + return (req: Request, res: Response, next: NextFunction): void => { + const key = req.ip || 'unknown'; + const now = Date.now(); + + let rateData = requestCounts.get(key); + if (!rateData || now > rateData.resetTime) { + rateData = { count: 0, resetTime: now + windowMs }; + requestCounts.set(key, rateData); + } + + if (rateData.count >= maxRequests) { + const retryAfter = Math.ceil((rateData.resetTime - now) / 1000); + if (responseFormat === 'mcp') { + res.status(429).json({ + jsonrpc: '2.0', + error: { + code: -32000, + message: 'Too many requests', + data: { retryAfter }, + }, + }); + } else { + res.status(429).set('Retry-After', retryAfter.toString()).send(` + + + Too Many Requests + + +

Too Many Requests

+

You're doing that too often! Try again in ${retryAfter} seconds.

+ + + `); + } + return; + } + + rateData.count++; + next(); + }; +} diff --git a/src/tools/tool.ts b/src/tools/tool.ts index 3984eb9d..7269f548 100644 --- a/src/tools/tool.ts +++ b/src/tools/tool.ts @@ -2,13 +2,19 @@ import { AuthInfo } from '@modelcontextprotocol/sdk/server/auth/types.js'; import { ToolCallback } from '@modelcontextprotocol/sdk/server/mcp.js'; import { CallToolResult, RequestId, ToolAnnotations } from '@modelcontextprotocol/sdk/types.js'; import { ZodiosError } from '@zodios/core'; +import { randomUUID } from 'crypto'; +import { existsSync, mkdirSync, writeFileSync } from 'fs'; +import { join } from 'path'; import { Result } from 'ts-results-es'; import { z, ZodRawShape, ZodTypeAny } from 'zod'; import { fromError, isZodErrorLike } from 'zod-validation-error'; +import { getConfig } from '../config.js'; +import { getServerUrl } from '../index.js'; import { getToolLogMessage, log } from '../logging/log.js'; import { Server } from '../server.js'; import { tableauAuthInfoSchema } from '../server/oauth/schemas.js'; +import { getDirname } from '../utils/getDirname.js'; import { getExceptionMessage } from '../utils/getExceptionMessage.js'; import { Provider, TypeOrProvider } from '../utils/provider.js'; import { ToolName } from './toolName.js'; @@ -57,6 +63,9 @@ export type ToolParams = { // The implementation of the tool itself callback: TypeOrProvider>; + + // Whether the result size of the tool is unlimited + isResultSizeUnlimited?: TypeOrProvider; }; /** @@ -103,6 +112,7 @@ export class Tool { annotations: TypeOrProvider; argsValidator?: TypeOrProvider>; callback: TypeOrProvider>; + isResultSizeUnlimited?: TypeOrProvider; constructor({ server, @@ -112,6 +122,7 @@ export class Tool { annotations, argsValidator, callback, + isResultSizeUnlimited, }: ToolParams) { this.server = server; this.name = name; @@ -120,6 +131,7 @@ export class Tool { this.annotations = annotations; this.argsValidator = argsValidator; this.callback = callback; + this.isResultSizeUnlimited = isResultSizeUnlimited; } logInvocation({ @@ -194,19 +206,31 @@ export class Tool { }; } + const isResultSizeUnlimited = await Provider.from(this.isResultSizeUnlimited); + const rowCount = Array.isArray(constrainedResult.result) + ? constrainedResult.result.length + : undefined; if (getSuccessResult) { - return getSuccessResult(constrainedResult.result); + const successResult = getSuccessResult(constrainedResult.result); + return isResultSizeUnlimited + ? successResult + : getSizeLimitedResult({ + result: getSuccessResult(constrainedResult.result), + rowCount, + }); } - return { + const successResult: CallToolResult = { isError: false, - content: [ - { - type: 'text', - text: JSON.stringify(constrainedResult.result), - }, - ], + content: [{ type: 'text', text: JSON.stringify(constrainedResult.result) }], }; + + return isResultSizeUnlimited + ? successResult + : getSizeLimitedResult({ + result: successResult, + rowCount, + }); } if (result.error instanceof ZodiosError) { @@ -232,6 +256,67 @@ export class Tool { } } +function getSizeLimitedResult({ + result, + rowCount, +}: { + result: CallToolResult; + rowCount: number | undefined; +}): CallToolResult { + const { resultSizeLimitKb, transport } = getConfig(); + if (resultSizeLimitKb === null) { + return result; + } + + if (result.content.length > 0 && result.content[0].type === 'text') { + const text = result.content[0].text; + const bytes = new TextEncoder().encode(text); + const fileSizeKb = Math.ceil(bytes.length / 1024); + + if (fileSizeKb > resultSizeLimitKb) { + const resultsDirectory = join(getDirname(), 'results'); + if (!existsSync(resultsDirectory)) { + mkdirSync(resultsDirectory, { recursive: true }); + } + + const filename = randomUUID(); + const fullFilePath = join(resultsDirectory, `${filename}.txt`); + writeFileSync(fullFilePath, text); + + const largeResult = { + status: 'size_limit_exceeded', + actual_size_kb: fileSizeKb, + file_resource_id: filename, + ...(rowCount !== undefined ? { row_count: rowCount } : {}), + ...(transport === 'http' + ? { file_resource_url: `${getServerUrl()}/results/${filename}` } + : { file_resource_path: fullFilePath }), + instruction: [ + 'The result is too large for the context window.', + 'Consider refining your original query with more specific filters (LIMIT, WHERE) to reduce the volume.', + 'You can also access the full results with a one-time request:', + ' 1) Use the get-large-result tool to retrieve them.', + transport === 'http' + ? " 2) Download them from the URL specified by the 'file_resource_url' field." + : " 2) View them in the file specified by the 'file_resource_path' field.", + 'Once accessed, the file will be deleted from the server.', + ].join('\n'), + }; + + return { + isError: true, + content: [ + { + type: 'text', + text: JSON.stringify(largeResult), + }, + ], + }; + } + } + return result; +} + function getErrorResult(requestId: RequestId, error: unknown): CallToolResult { if (error instanceof ZodiosError && isZodErrorLike(error.cause)) { // Schema validation errors on otherwise successful API calls will not return an "error" result to the MCP client. diff --git a/src/tools/toolName.ts b/src/tools/toolName.ts index 137e95a4..fd004a43 100644 --- a/src/tools/toolName.ts +++ b/src/tools/toolName.ts @@ -15,6 +15,7 @@ export const toolNames = [ 'generate-pulse-metric-value-insight-bundle', 'generate-pulse-insight-brief', 'search-content', + 'get-large-result', ] as const; export type ToolName = (typeof toolNames)[number]; @@ -24,6 +25,7 @@ export const toolGroupNames = [ 'view', 'pulse', 'content-exploration', + 'utility', ] as const; export type ToolGroupName = (typeof toolGroupNames)[number]; @@ -41,6 +43,7 @@ export const toolGroups = { 'generate-pulse-insight-brief', ], 'content-exploration': ['search-content'], + utility: ['get-large-result'], } as const satisfies Record>; export function isToolName(value: unknown): value is ToolName { diff --git a/src/tools/tools.ts b/src/tools/tools.ts index 9dbc212c..c5c86f89 100644 --- a/src/tools/tools.ts +++ b/src/tools/tools.ts @@ -9,6 +9,7 @@ import { getListPulseMetricsFromMetricDefinitionIdTool } from './pulse/listMetri import { getListPulseMetricsFromMetricIdsTool } from './pulse/listMetricsFromMetricIds/listPulseMetricsFromMetricIds.js'; import { getListPulseMetricSubscriptionsTool } from './pulse/listMetricSubscriptions/listPulseMetricSubscriptions.js'; import { getQueryDatasourceTool } from './queryDatasource/queryDatasource.js'; +import { getLargeResultTool } from './utility/getLargeResult.js'; import { getGetViewDataTool } from './views/getViewData.js'; import { getGetViewImageTool } from './views/getViewImage.js'; import { getListViewsTool } from './views/listViews.js'; @@ -32,4 +33,5 @@ export const toolFactories = [ getListWorkbooksTool, getListViewsTool, getSearchContentTool, + getLargeResultTool, ]; diff --git a/src/tools/utility/getLargeResult.ts b/src/tools/utility/getLargeResult.ts new file mode 100644 index 00000000..0b04dff1 --- /dev/null +++ b/src/tools/utility/getLargeResult.ts @@ -0,0 +1,56 @@ +import { CallToolResult } from '@modelcontextprotocol/sdk/types.js'; +import { readFileSync, unlinkSync } from 'fs'; +import { Err, Ok } from 'ts-results-es'; +import { z } from 'zod'; + +import { Server } from '../../server.js'; +import { getLargeResultFilePath } from '../../server/getLargeResult.js'; +import { getExceptionMessage } from '../../utils/getExceptionMessage.js'; +import { Tool } from '../tool.js'; + +const paramsSchema = { + fileResourceId: z.string(), +}; + +export const getLargeResultTool = (server: Server): Tool => { + const largeResultTool = new Tool({ + server, + name: 'get-large-result', + description: + 'This tool retrieves a large result that was previously generated by another tool. The result is stored on the server and can be accessed by providing the file resource ID.', + paramsSchema, + isResultSizeUnlimited: true, + annotations: { + title: 'Get Large Result', + readOnlyHint: true, + destructiveHint: true, // The result is deleted from the server after it is accessed. + openWorldHint: false, + }, + callback: async ({ fileResourceId }, { requestId, authInfo }): Promise => { + return await largeResultTool.logAndExecute({ + requestId, + authInfo, + args: { fileResourceId }, + callback: async () => { + const result = getLargeResultFilePath(fileResourceId); + if (result.isErr()) { + return Err(result.error.message); + } + + try { + const contents = readFileSync(result.value.fullFilePath); + const text = contents.toString('utf8'); + unlinkSync(result.value.fullFilePath); + return Ok(JSON.parse(text)); + } catch (error) { + return Err(getExceptionMessage(error)); + } + }, + constrainSuccessResult: (result) => ({ type: 'success', result }), + getErrorText: (error) => error, + }); + }, + }); + + return largeResultTool; +}; diff --git a/types/process-env.d.ts b/types/process-env.d.ts index 4852b5c5..4e1235f7 100644 --- a/types/process-env.d.ts +++ b/types/process-env.d.ts @@ -38,6 +38,7 @@ export interface ProcessEnvEx { INCLUDE_DATASOURCE_IDS: string | undefined; INCLUDE_WORKBOOK_IDS: string | undefined; TABLEAU_SERVER_VERSION_CHECK_INTERVAL_IN_HOURS: string | undefined; + RESULT_SIZE_LIMIT_KB: string | undefined; DANGEROUSLY_DISABLE_OAUTH: string | undefined; OAUTH_ISSUER: string | undefined; OAUTH_JWE_PRIVATE_KEY: string | undefined;