diff --git a/package-lock.json b/package-lock.json index 0378ea39..8f8a1016 100644 --- a/package-lock.json +++ b/package-lock.json @@ -1,12 +1,12 @@ { "name": "@tableau/mcp-server", - "version": "1.14.6", + "version": "1.14.7", "lockfileVersion": 3, "requires": true, "packages": { "": { "name": "@tableau/mcp-server", - "version": "1.14.6", + "version": "1.14.7", "license": "Apache-2.0", "dependencies": { "@modelcontextprotocol/sdk": "^1.25.2", diff --git a/package.json b/package.json index c345912b..7eab637d 100644 --- a/package.json +++ b/package.json @@ -1,7 +1,7 @@ { "name": "@tableau/mcp-server", "description": "An MCP server for Tableau, providing a suite of tools that will make it easier for developers to build AI applications that integrate with Tableau.", - "version": "1.14.6", + "version": "1.14.7", "repository": { "type": "git", "url": "git+https://github.com/tableau/tableau-mcp.git" diff --git a/src/config.test.ts b/src/config.test.ts index 3a16e59a..f30f4dea 100644 --- a/src/config.test.ts +++ b/src/config.test.ts @@ -1,5 +1,3 @@ -import { beforeEach, describe, expect, it, vi } from 'vitest'; - import { exportedForTesting, ONE_HOUR_IN_MS, TEN_MINUTES_IN_MS } from './config.js'; import { stubDefaultEnvVars } from './testShared.js'; @@ -128,30 +126,6 @@ describe('Config', () => { expect(config.maxRequestTimeoutMs).toBe(TEN_MINUTES_IN_MS); }); - it('should set disableQueryDatasourceValidationRequests to false by default', () => { - const config = new Config(); - expect(config.disableQueryDatasourceValidationRequests).toBe(false); - }); - - it('should set disableQueryDatasourceValidationRequests to true when specified', () => { - vi.stubEnv('DISABLE_QUERY_DATASOURCE_VALIDATION_REQUESTS', 'true'); - - const config = new Config(); - expect(config.disableQueryDatasourceValidationRequests).toBe(true); - }); - - it('should set disableMetadataApiRequests to false by default', () => { - const config = new Config(); - expect(config.disableMetadataApiRequests).toBe(false); - }); - - it('should set disableMetadataApiRequests to true when specified', () => { - vi.stubEnv('DISABLE_METADATA_API_REQUESTS', 'true'); - - const config = new Config(); - expect(config.disableMetadataApiRequests).toBe(true); - }); - it('should set disableSessionManagement to false by default', () => { const config = new Config(); expect(config.disableSessionManagement).toBe(false); @@ -189,68 +163,28 @@ describe('Config', () => { expect(config.tableauServerVersionCheckIntervalInHours).toBe(2); }); - describe('Tool filtering', () => { - it('should set empty arrays for includeTools and excludeTools when not specified', () => { - const config = new Config(); - expect(config.includeTools).toEqual([]); - expect(config.excludeTools).toEqual([]); - }); - - it('should parse INCLUDE_TOOLS into an array of valid tool names', () => { - vi.stubEnv('INCLUDE_TOOLS', 'query-datasource,get-datasource-metadata'); - - const config = new Config(); - expect(config.includeTools).toEqual(['query-datasource', 'get-datasource-metadata']); - }); - - it('should parse INCLUDE_TOOLS into an array of valid tool names when tool group names are used', () => { - vi.stubEnv('INCLUDE_TOOLS', 'query-datasource,workbook'); - - const config = new Config(); - expect(config.includeTools).toEqual(['query-datasource', 'list-workbooks', 'get-workbook']); - }); - - it('should parse EXCLUDE_TOOLS into an array of valid tool names', () => { - vi.stubEnv('EXCLUDE_TOOLS', 'query-datasource'); - - const config = new Config(); - expect(config.excludeTools).toEqual(['query-datasource']); - }); - - it('should parse EXCLUDE_TOOLS into an array of valid tool names when tool group names are used', () => { - vi.stubEnv('EXCLUDE_TOOLS', 'query-datasource,workbook'); - - const config = new Config(); - expect(config.excludeTools).toEqual(['query-datasource', 'list-workbooks', 'get-workbook']); - }); - - it('should filter out invalid tool names from INCLUDE_TOOLS', () => { - vi.stubEnv('INCLUDE_TOOLS', 'query-datasource,order-hamburgers'); - - const config = new Config(); - expect(config.includeTools).toEqual(['query-datasource']); - }); - - it('should filter out invalid tool names from EXCLUDE_TOOLS', () => { - vi.stubEnv('EXCLUDE_TOOLS', 'query-datasource,order-hamburgers'); + it('should set mcpSiteSettingsCheckIntervalInMinutes to default when not specified', () => { + const config = new Config(); + expect(config.mcpSiteSettingsCheckIntervalInMinutes).toBe(10); + }); - const config = new Config(); - expect(config.excludeTools).toEqual(['query-datasource']); - }); + it('should set mcpSiteSettingsCheckIntervalInMinutes to the specified value when specified', () => { + vi.stubEnv('MCP_SITE_SETTINGS_CHECK_INTERVAL_IN_MINUTES', '2'); - it('should throw error when both INCLUDE_TOOLS and EXCLUDE_TOOLS are specified', () => { - vi.stubEnv('INCLUDE_TOOLS', 'query-datasource'); - vi.stubEnv('EXCLUDE_TOOLS', 'get-datasource-metadata'); + const config = new Config(); + expect(config.mcpSiteSettingsCheckIntervalInMinutes).toBe(2); + }); - expect(() => new Config()).toThrow('Cannot include and exclude tools simultaneously'); - }); + it('should set enableMcpSiteSettings to false by default', () => { + const config = new Config(); + expect(config.enableMcpSiteSettings).toBe(false); + }); - it('should throw error when both INCLUDE_TOOLS and EXCLUDE_TOOLS are specified with tool group names', () => { - vi.stubEnv('INCLUDE_TOOLS', 'datasource'); - vi.stubEnv('EXCLUDE_TOOLS', 'workbook'); + it('should set enableMcpSiteSettings to true when specified', () => { + vi.stubEnv('ENABLE_MCP_SITE_SETTINGS', 'true'); - expect(() => new Config()).toThrow('Cannot include and exclude tools simultaneously'); - }); + const config = new Config(); + expect(config.enableMcpSiteSettings).toBe(true); }); describe('HTTP server config parsing', () => { @@ -581,65 +515,6 @@ describe('Config', () => { }); }); - describe('Bounded context parsing', () => { - it('should set boundedContext to null sets when no project, datasource, or workbook IDs are provided', () => { - const config = new Config(); - expect(config.boundedContext).toEqual({ - projectIds: null, - datasourceIds: null, - workbookIds: null, - tags: null, - }); - }); - - it('should set boundedContext to the specified tags and project, datasource, and workbook IDs when provided', () => { - vi.stubEnv('INCLUDE_PROJECT_IDS', ' 123, 456, 123 '); // spacing is intentional here to test trimming - vi.stubEnv('INCLUDE_DATASOURCE_IDS', '789,101'); - vi.stubEnv('INCLUDE_WORKBOOK_IDS', '112,113'); - vi.stubEnv('INCLUDE_TAGS', 'tag1,tag2'); - - const config = new Config(); - expect(config.boundedContext).toEqual({ - projectIds: new Set(['123', '456']), - datasourceIds: new Set(['789', '101']), - workbookIds: new Set(['112', '113']), - tags: new Set(['tag1', 'tag2']), - }); - }); - - it('should throw error when INCLUDE_PROJECT_IDS is set to an empty string', () => { - vi.stubEnv('INCLUDE_PROJECT_IDS', ''); - - expect(() => new Config()).toThrow( - 'When set, the environment variable INCLUDE_PROJECT_IDS must have at least one value', - ); - }); - - it('should throw error when INCLUDE_DATASOURCE_IDS is set to an empty string', () => { - vi.stubEnv('INCLUDE_DATASOURCE_IDS', ''); - - expect(() => new Config()).toThrow( - 'When set, the environment variable INCLUDE_DATASOURCE_IDS must have at least one value', - ); - }); - - it('should throw error when INCLUDE_WORKBOOK_IDS is set to an empty string', () => { - vi.stubEnv('INCLUDE_WORKBOOK_IDS', ''); - - expect(() => new Config()).toThrow( - 'When set, the environment variable INCLUDE_WORKBOOK_IDS must have at least one value', - ); - }); - - it('should throw error when INCLUDE_TAGS is set to an empty string', () => { - vi.stubEnv('INCLUDE_TAGS', ''); - - expect(() => new Config()).toThrow( - 'When set, the environment variable INCLUDE_TAGS must have at least one value', - ); - }); - }); - describe('OAuth configuration', () => { function stubDefaultOAuthEnvVars(): void { vi.stubEnv('OAUTH_ISSUER', 'https://example.com'); @@ -971,65 +846,4 @@ describe('Config', () => { expect(result).toBe(42); }); }); - - describe('Max results limit parsing', () => { - it('should return null when MAX_RESULT_LIMIT and MAX_RESULT_LIMITS are not set', () => { - expect(new Config().getMaxResultLimit('query-datasource')).toBeNull(); - }); - - it('should return the max result limit when MAX_RESULT_LIMITS has a single tool', () => { - vi.stubEnv('MAX_RESULT_LIMITS', 'query-datasource:100'); - - expect(new Config().getMaxResultLimit('query-datasource')).toEqual(100); - }); - - it('should return the max result limit when MAX_RESULT_LIMITS has a single tool group', () => { - vi.stubEnv('MAX_RESULT_LIMITS', 'datasource:200'); - - expect(new Config().getMaxResultLimit('query-datasource')).toEqual(200); - }); - - it('should return the max result limit for the tool when a tool and a tool group are both specified', () => { - vi.stubEnv('MAX_RESULT_LIMITS', 'query-datasource:100,datasource:200'); - - expect(new Config().getMaxResultLimit('query-datasource')).toEqual(100); - expect(new Config().getMaxResultLimit('list-datasources')).toEqual(200); - }); - - it('should fallback to MAX_RESULT_LIMIT when a tool-specific max result limit is not set', () => { - vi.stubEnv('MAX_RESULT_LIMITS', 'query-datasource:100'); - vi.stubEnv('MAX_RESULT_LIMIT', '300'); - - expect(new Config().getMaxResultLimit('query-datasource')).toEqual(100); - expect(new Config().getMaxResultLimit('list-datasources')).toEqual(300); - }); - - it('should return null when MAX_RESULT_LIMITS has a non-number', () => { - vi.stubEnv('MAX_RESULT_LIMITS', 'query-datasource:abc'); - - const config = new Config(); - expect(config.getMaxResultLimit('query-datasource')).toBe(null); - }); - - it('should return null when MAX_RESULT_LIMIT is specified as a non-number', () => { - vi.stubEnv('MAX_RESULT_LIMIT', 'abc'); - - const config = new Config(); - expect(config.getMaxResultLimit('query-datasource')).toBe(null); - }); - - it('should return null when MAX_RESULT_LIMITS has a negative number', () => { - vi.stubEnv('MAX_RESULT_LIMITS', 'query-datasource:-100'); - - const config = new Config(); - expect(config.getMaxResultLimit('query-datasource')).toBe(null); - }); - - it('should return null when MAX_RESULT_LIMIT is specified as a negative number', () => { - vi.stubEnv('MAX_RESULT_LIMIT', '-100'); - - const config = new Config(); - expect(config.getMaxResultLimit('query-datasource')).toBe(null); - }); - }); }); diff --git a/src/config.ts b/src/config.ts index 7e15205e..3c788122 100644 --- a/src/config.ts +++ b/src/config.ts @@ -3,7 +3,6 @@ import { existsSync, readFileSync } from 'fs'; import { join } from 'path'; import { isTelemetryProvider, providerConfigSchema, TelemetryConfig } from './telemetry/types.js'; -import { isToolGroupName, isToolName, toolGroups, ToolName } from './tools/toolName.js'; import { isTransport, TransportName } from './transports.js'; import { getDirname } from './utils/getDirname.js'; import invariant from './utils/invariant.js'; @@ -20,20 +19,10 @@ const authTypes = ['pat', 'uat', 'direct-trust', 'oauth'] as const; type AuthType = (typeof authTypes)[number]; function isAuthType(auth: unknown): auth is AuthType { - return !!authTypes.find((type) => type === auth); + return authTypes.some((type) => type === auth); } -export type BoundedContext = { - projectIds: Set | null; - datasourceIds: Set | null; - workbookIds: Set | null; - tags: Set | null; -}; - export class Config { - private maxResultLimit: number | null; - private maxResultLimits: Map | null; - auth: AuthType; server: string; transport: TransportName; @@ -58,16 +47,13 @@ export class Config { datasourceCredentials: string; defaultLogLevel: string; disableLogMasking: boolean; - includeTools: Array; - excludeTools: Array; maxRequestTimeoutMs: number; - disableQueryDatasourceValidationRequests: boolean; - disableMetadataApiRequests: boolean; disableSessionManagement: boolean; enableServerLogging: boolean; serverLogDirectory: string; - boundedContext: BoundedContext; tableauServerVersionCheckIntervalInHours: number; + mcpSiteSettingsCheckIntervalInMinutes: number; + enableMcpSiteSettings: boolean; oauth: { enabled: boolean; issuer: string; @@ -86,10 +72,6 @@ export class Config { productTelemetryEndpoint: string; productTelemetryEnabled: boolean; - getMaxResultLimit(toolName: ToolName): number | null { - return this.maxResultLimits?.get(toolName) ?? this.maxResultLimit; - } - constructor() { const cleansedVars = removeClaudeMcpBundleUserConfigTemplates(process.env); const { @@ -119,21 +101,13 @@ export class Config { DATASOURCE_CREDENTIALS: datasourceCredentials, DEFAULT_LOG_LEVEL: defaultLogLevel, DISABLE_LOG_MASKING: disableLogMasking, - INCLUDE_TOOLS: includeTools, - EXCLUDE_TOOLS: excludeTools, MAX_REQUEST_TIMEOUT_MS: maxRequestTimeoutMs, - MAX_RESULT_LIMIT: maxResultLimit, - MAX_RESULT_LIMITS: maxResultLimits, - DISABLE_QUERY_DATASOURCE_VALIDATION_REQUESTS: disableQueryDatasourceValidationRequests, - DISABLE_METADATA_API_REQUESTS: disableMetadataApiRequests, DISABLE_SESSION_MANAGEMENT: disableSessionManagement, ENABLE_SERVER_LOGGING: enableServerLogging, SERVER_LOG_DIRECTORY: serverLogDirectory, - INCLUDE_PROJECT_IDS: includeProjectIds, - INCLUDE_DATASOURCE_IDS: includeDatasourceIds, - INCLUDE_WORKBOOK_IDS: includeWorkbookIds, - INCLUDE_TAGS: includeTags, TABLEAU_SERVER_VERSION_CHECK_INTERVAL_IN_HOURS: tableauServerVersionCheckIntervalInHours, + MCP_SITE_SETTINGS_CHECK_INTERVAL_IN_MINUTES: mcpSiteSettingsCheckIntervalInMinutes, + ENABLE_MCP_SITE_SETTINGS: enableMcpSiteSettings, DANGEROUSLY_DISABLE_OAUTH: disableOauth, OAUTH_ISSUER: oauthIssuer, OAUTH_LOCK_SITE: oauthLockSite, @@ -168,42 +142,9 @@ export class Config { this.datasourceCredentials = datasourceCredentials ?? ''; this.defaultLogLevel = defaultLogLevel ?? 'debug'; this.disableLogMasking = disableLogMasking === 'true'; - this.disableQueryDatasourceValidationRequests = - disableQueryDatasourceValidationRequests === 'true'; - this.disableMetadataApiRequests = disableMetadataApiRequests === 'true'; this.disableSessionManagement = disableSessionManagement === 'true'; this.enableServerLogging = enableServerLogging === 'true'; this.serverLogDirectory = serverLogDirectory || join(__dirname, 'logs'); - this.boundedContext = { - projectIds: createSetFromCommaSeparatedString(includeProjectIds), - datasourceIds: createSetFromCommaSeparatedString(includeDatasourceIds), - workbookIds: createSetFromCommaSeparatedString(includeWorkbookIds), - tags: createSetFromCommaSeparatedString(includeTags), - }; - - if (this.boundedContext.projectIds?.size === 0) { - throw new Error( - 'When set, the environment variable INCLUDE_PROJECT_IDS must have at least one value', - ); - } - - if (this.boundedContext.datasourceIds?.size === 0) { - throw new Error( - 'When set, the environment variable INCLUDE_DATASOURCE_IDS must have at least one value', - ); - } - - if (this.boundedContext.workbookIds?.size === 0) { - throw new Error( - 'When set, the environment variable INCLUDE_WORKBOOK_IDS must have at least one value', - ); - } - - if (this.boundedContext.tags?.size === 0) { - throw new Error( - 'When set, the environment variable INCLUDE_TAGS must have at least one value', - ); - } this.tableauServerVersionCheckIntervalInHours = parseNumber( tableauServerVersionCheckIntervalInHours, @@ -214,6 +155,16 @@ export class Config { }, ); + this.mcpSiteSettingsCheckIntervalInMinutes = parseNumber( + mcpSiteSettingsCheckIntervalInMinutes, + { + defaultValue: 10, + minValue: 1, + maxValue: 60 * 24, // 24 hours + }, + ); + + this.enableMcpSiteSettings = enableMcpSiteSettings === 'true'; const disableOauthOverride = disableOauth === 'true'; this.oauth = { enabled: disableOauthOverride ? false : !!oauthIssuer, @@ -331,30 +282,6 @@ export class Config { maxValue: ONE_HOUR_IN_MS, }); - const maxResultLimitNumber = maxResultLimit ? parseInt(maxResultLimit) : NaN; - this.maxResultLimit = - isNaN(maxResultLimitNumber) || maxResultLimitNumber <= 0 ? null : maxResultLimitNumber; - - this.maxResultLimits = maxResultLimits ? getMaxResultLimits(maxResultLimits) : null; - - this.includeTools = includeTools - ? includeTools.split(',').flatMap((s) => { - const v = s.trim(); - return isToolName(v) ? v : isToolGroupName(v) ? toolGroups[v] : []; - }) - : []; - - this.excludeTools = excludeTools - ? excludeTools.split(',').flatMap((s) => { - const v = s.trim(); - return isToolName(v) ? v : isToolGroupName(v) ? toolGroups[v] : []; - }) - : []; - - if (this.includeTools.length > 0 && this.excludeTools.length > 0) { - throw new Error('Cannot include and exclude tools simultaneously'); - } - if (this.auth === 'pat') { invariant(patName, 'The environment variable PAT_NAME is not set'); invariant(patValue, 'The environment variable PAT_VALUE is not set'); @@ -481,25 +408,9 @@ function getTrustProxyConfig(trustProxyConfig: string): boolean | number | strin return trustProxyConfig; } -// Creates a set from a comma-separated string of values. -// Returns null if the value is undefined. -function createSetFromCommaSeparatedString(value: string | undefined): Set | null { - if (value === undefined) { - return null; - } - - return new Set( - value - .trim() - .split(',') - .map((id) => id.trim()) - .filter(Boolean), - ); -} - // When the user does not provide a site name in the Claude MCP Bundle configuration, // Claude doesn't replace its value and sets the site name to "${user_config.site_name}". -function removeClaudeMcpBundleUserConfigTemplates( +export function removeClaudeMcpBundleUserConfigTemplates( envVars: Record, ): Record { return Object.entries(envVars).reduce>((acc, [key, value]) => { @@ -512,32 +423,6 @@ function removeClaudeMcpBundleUserConfigTemplates( }, {}); } -function getMaxResultLimits(maxResultLimits: string): Map { - const map = new Map(); - if (!maxResultLimits) { - return map; - } - - maxResultLimits.split(',').forEach((curr) => { - const [toolName, maxResultLimit] = curr.split(':'); - const maxResultLimitNumber = maxResultLimit ? parseInt(maxResultLimit) : NaN; - const actualLimit = - isNaN(maxResultLimitNumber) || maxResultLimitNumber <= 0 ? null : maxResultLimitNumber; - if (isToolName(toolName)) { - map.set(toolName, actualLimit); - } else if (isToolGroupName(toolName)) { - toolGroups[toolName].forEach((toolName) => { - if (!map.has(toolName)) { - // Tool names take precedence over group names - map.set(toolName, actualLimit); - } - }); - } - }); - - return map; -} - function parseNumber( value: string | undefined, { diff --git a/src/overridableConfig.test.ts b/src/overridableConfig.test.ts new file mode 100644 index 00000000..4ef9ae83 --- /dev/null +++ b/src/overridableConfig.test.ts @@ -0,0 +1,332 @@ +import { exportedForTesting } from './overridableConfig.js'; +import { stubDefaultEnvVars } from './testShared.js'; + +describe('OverridableConfig', () => { + const { OverridableConfig } = exportedForTesting; + + beforeEach(() => { + vi.resetModules(); + vi.unstubAllEnvs(); + stubDefaultEnvVars(); + }); + + afterEach(() => { + vi.unstubAllEnvs(); + }); + + it('should set disableQueryDatasourceValidationRequests to false by default', () => { + const config = new OverridableConfig({}); + expect(config.disableQueryDatasourceValidationRequests).toBe(false); + }); + + it('should set disableQueryDatasourceValidationRequests to true when specified', () => { + vi.stubEnv('DISABLE_QUERY_DATASOURCE_VALIDATION_REQUESTS', 'true'); + + const config = new OverridableConfig({}); + expect(config.disableQueryDatasourceValidationRequests).toBe(true); + }); + + it('should set disableMetadataApiRequests to false by default', () => { + const config = new OverridableConfig({}); + expect(config.disableMetadataApiRequests).toBe(false); + }); + + it('should set disableMetadataApiRequests to true when specified', () => { + vi.stubEnv('DISABLE_METADATA_API_REQUESTS', 'true'); + + const config = new OverridableConfig({}); + expect(config.disableMetadataApiRequests).toBe(true); + }); + + describe('Tool filtering', () => { + it('should set empty arrays for includeTools and excludeTools when not specified', () => { + const config = new OverridableConfig({}); + expect(config.includeTools).toEqual([]); + expect(config.excludeTools).toEqual([]); + }); + + it('should parse INCLUDE_TOOLS into an array of valid tool names', () => { + vi.stubEnv('INCLUDE_TOOLS', 'query-datasource,get-datasource-metadata'); + + const config = new OverridableConfig({}); + expect(config.includeTools).toEqual(['query-datasource', 'get-datasource-metadata']); + }); + + it('should parse INCLUDE_TOOLS into an array of valid tool names when tool group names are used', () => { + vi.stubEnv('INCLUDE_TOOLS', 'query-datasource,workbook'); + + const config = new OverridableConfig({}); + expect(config.includeTools).toEqual(['query-datasource', 'list-workbooks', 'get-workbook']); + }); + + it('should parse EXCLUDE_TOOLS into an array of valid tool names', () => { + vi.stubEnv('EXCLUDE_TOOLS', 'query-datasource'); + + const config = new OverridableConfig({}); + expect(config.excludeTools).toEqual(['query-datasource']); + }); + + it('should parse EXCLUDE_TOOLS into an array of valid tool names when tool group names are used', () => { + vi.stubEnv('EXCLUDE_TOOLS', 'query-datasource,workbook'); + + const config = new OverridableConfig({}); + expect(config.excludeTools).toEqual(['query-datasource', 'list-workbooks', 'get-workbook']); + }); + + it('should filter out invalid tool names from INCLUDE_TOOLS', () => { + vi.stubEnv('INCLUDE_TOOLS', 'query-datasource,order-hamburgers'); + + const config = new OverridableConfig({}); + expect(config.includeTools).toEqual(['query-datasource']); + }); + + it('should filter out invalid tool names from EXCLUDE_TOOLS', () => { + vi.stubEnv('EXCLUDE_TOOLS', 'query-datasource,order-hamburgers'); + + const config = new OverridableConfig({}); + expect(config.excludeTools).toEqual(['query-datasource']); + }); + + it('should throw error when both INCLUDE_TOOLS and EXCLUDE_TOOLS are specified', () => { + vi.stubEnv('INCLUDE_TOOLS', 'query-datasource'); + vi.stubEnv('EXCLUDE_TOOLS', 'get-datasource-metadata'); + + expect(() => new OverridableConfig({})).toThrow( + 'Cannot include and exclude tools simultaneously', + ); + }); + + it('should throw error when both INCLUDE_TOOLS and EXCLUDE_TOOLS are specified with tool group names', () => { + vi.stubEnv('INCLUDE_TOOLS', 'datasource'); + vi.stubEnv('EXCLUDE_TOOLS', 'workbook'); + expect(() => new OverridableConfig({})).toThrow( + 'Cannot include and exclude tools simultaneously', + ); + }); + }); + + describe('Bounded context parsing', () => { + it('should set boundedContext to null sets when no project, datasource, or workbook IDs are provided', () => { + const config = new OverridableConfig({}); + expect(config.boundedContext).toEqual({ + projectIds: null, + datasourceIds: null, + workbookIds: null, + tags: null, + }); + }); + + it('should set boundedContext to the specified tags and project, datasource, and workbook IDs when provided', () => { + vi.stubEnv('INCLUDE_PROJECT_IDS', ' 123, 456, 123 '); // spacing is intentional here to test trimming + vi.stubEnv('INCLUDE_DATASOURCE_IDS', '789,101'); + vi.stubEnv('INCLUDE_WORKBOOK_IDS', '112,113'); + vi.stubEnv('INCLUDE_TAGS', 'tag1,tag2'); + + const config = new OverridableConfig({}); + expect(config.boundedContext).toEqual({ + projectIds: new Set(['123', '456']), + datasourceIds: new Set(['789', '101']), + workbookIds: new Set(['112', '113']), + tags: new Set(['tag1', 'tag2']), + }); + }); + + it('should throw error when INCLUDE_PROJECT_IDS is set to an empty string', () => { + vi.stubEnv('INCLUDE_PROJECT_IDS', ''); + + expect(() => new OverridableConfig({})).toThrow( + 'When set, the environment variable INCLUDE_PROJECT_IDS must have at least one value', + ); + }); + + it('should throw error when INCLUDE_DATASOURCE_IDS is set to an empty string', () => { + vi.stubEnv('INCLUDE_DATASOURCE_IDS', ''); + + expect(() => new OverridableConfig({})).toThrow( + 'When set, the environment variable INCLUDE_DATASOURCE_IDS must have at least one value', + ); + }); + + it('should throw error when INCLUDE_WORKBOOK_IDS is set to an empty string', () => { + vi.stubEnv('INCLUDE_WORKBOOK_IDS', ''); + + expect(() => new OverridableConfig({})).toThrow( + 'When set, the environment variable INCLUDE_WORKBOOK_IDS must have at least one value', + ); + }); + + it('should throw error when INCLUDE_TAGS is set to an empty string', () => { + vi.stubEnv('INCLUDE_TAGS', ''); + + expect(() => new OverridableConfig({})).toThrow( + 'When set, the environment variable INCLUDE_TAGS must have at least one value', + ); + }); + }); + + describe('Max results limit parsing', () => { + it('should return null when MAX_RESULT_LIMIT and MAX_RESULT_LIMITS are not set', () => { + expect(new OverridableConfig({}).getMaxResultLimit('query-datasource')).toBeNull(); + }); + + it('should return the max result limit when MAX_RESULT_LIMITS has a single tool', () => { + vi.stubEnv('MAX_RESULT_LIMITS', 'query-datasource:100'); + + expect(new OverridableConfig({}).getMaxResultLimit('query-datasource')).toEqual(100); + }); + + it('should return the max result limit when MAX_RESULT_LIMITS has a single tool group', () => { + vi.stubEnv('MAX_RESULT_LIMITS', 'datasource:200'); + + expect(new OverridableConfig({}).getMaxResultLimit('query-datasource')).toEqual(200); + }); + + it('should return the max result limit for the tool when a tool and a tool group are both specified', () => { + vi.stubEnv('MAX_RESULT_LIMITS', 'query-datasource:100,datasource:200'); + + expect(new OverridableConfig({}).getMaxResultLimit('query-datasource')).toEqual(100); + expect(new OverridableConfig({}).getMaxResultLimit('list-datasources')).toEqual(200); + }); + + it('should fallback to MAX_RESULT_LIMIT when a tool-specific max result limit is not set', () => { + vi.stubEnv('MAX_RESULT_LIMITS', 'query-datasource:100'); + vi.stubEnv('MAX_RESULT_LIMIT', '300'); + + expect(new OverridableConfig({}).getMaxResultLimit('query-datasource')).toEqual(100); + expect(new OverridableConfig({}).getMaxResultLimit('list-datasources')).toEqual(300); + }); + + it('should return null when MAX_RESULT_LIMITS has a non-number', () => { + vi.stubEnv('MAX_RESULT_LIMITS', 'query-datasource:abc'); + + const config = new OverridableConfig({}); + expect(config.getMaxResultLimit('query-datasource')).toBe(null); + }); + + it('should return null when MAX_RESULT_LIMIT is specified as a non-number', () => { + vi.stubEnv('MAX_RESULT_LIMIT', 'abc'); + + const config = new OverridableConfig({}); + expect(config.getMaxResultLimit('query-datasource')).toBe(null); + }); + + it('should return null when MAX_RESULT_LIMITS has a negative number', () => { + vi.stubEnv('MAX_RESULT_LIMITS', 'query-datasource:-100'); + + const config = new OverridableConfig({}); + expect(config.getMaxResultLimit('query-datasource')).toBe(null); + }); + + it('should return null when MAX_RESULT_LIMIT is specified as a negative number', () => { + vi.stubEnv('MAX_RESULT_LIMIT', '-100'); + + const config = new OverridableConfig({}); + expect(config.getMaxResultLimit('query-datasource')).toBe(null); + }); + }); + + describe('Override behavior', () => { + it('should override INCLUDE_TOOLS', () => { + vi.stubEnv('INCLUDE_TOOLS', 'list-views'); + + const config = new OverridableConfig({ + INCLUDE_TOOLS: 'query-datasource', + }); + + expect(config.includeTools).toEqual(['query-datasource']); + }); + + it('should override EXCLUDE_TOOLS', () => { + vi.stubEnv('EXCLUDE_TOOLS', 'list-views'); + + const config = new OverridableConfig({ + EXCLUDE_TOOLS: 'get-datasource-metadata', + }); + + expect(config.excludeTools).toEqual(['get-datasource-metadata']); + }); + + it('should override INCLUDE_PROJECT_IDS', () => { + vi.stubEnv('INCLUDE_PROJECT_IDS', '999'); + + const config = new OverridableConfig({ + INCLUDE_PROJECT_IDS: '123,456', + }); + + expect(config.boundedContext.projectIds).toEqual(new Set(['123', '456'])); + }); + + it('should override INCLUDE_DATASOURCE_IDS', () => { + vi.stubEnv('INCLUDE_DATASOURCE_IDS', '999'); + + const config = new OverridableConfig({ + INCLUDE_DATASOURCE_IDS: '123,456', + }); + + expect(config.boundedContext.datasourceIds).toEqual(new Set(['123', '456'])); + }); + + it('should override INCLUDE_WORKBOOK_IDS', () => { + vi.stubEnv('INCLUDE_WORKBOOK_IDS', '999'); + + const config = new OverridableConfig({ + INCLUDE_WORKBOOK_IDS: '123,456', + }); + + expect(config.boundedContext.workbookIds).toEqual(new Set(['123', '456'])); + }); + + it('should override INCLUDE_TAGS', () => { + vi.stubEnv('INCLUDE_TAGS', '999'); + + const config = new OverridableConfig({ + INCLUDE_TAGS: '123,456', + }); + + expect(config.boundedContext.tags).toEqual(new Set(['123', '456'])); + }); + + it('should override MAX_RESULT_LIMIT', () => { + vi.stubEnv('MAX_RESULT_LIMIT', '10'); + + const config = new OverridableConfig({ + MAX_RESULT_LIMIT: '99', + }); + + expect(config.getMaxResultLimit('query-datasource')).toEqual(99); + }); + + it('should override MAX_RESULT_LIMITS', () => { + vi.stubEnv('MAX_RESULT_LIMIT', '10'); + vi.stubEnv('MAX_RESULT_LIMITS', 'query-datasource:100'); + + const config = new OverridableConfig({ + MAX_RESULT_LIMIT: '99', + MAX_RESULT_LIMITS: 'query-datasource:999', + }); + + expect(config.getMaxResultLimit('list-datasources')).toEqual(99); + expect(config.getMaxResultLimit('query-datasource')).toEqual(999); + }); + + it('should override DISABLE_QUERY_DATASOURCE_VALIDATION_REQUESTS', () => { + vi.stubEnv('DISABLE_QUERY_DATASOURCE_VALIDATION_REQUESTS', 'false'); + + const config = new OverridableConfig({ + DISABLE_QUERY_DATASOURCE_VALIDATION_REQUESTS: 'true', + }); + + expect(config.disableQueryDatasourceValidationRequests).toEqual(true); + }); + + it('should override DISABLE_METADATA_API_REQUESTS', () => { + vi.stubEnv('DISABLE_METADATA_API_REQUESTS', 'false'); + + const config = new OverridableConfig({ + DISABLE_METADATA_API_REQUESTS: 'true', + }); + + expect(config.disableMetadataApiRequests).toEqual(true); + }); + }); +}); diff --git a/src/overridableConfig.ts b/src/overridableConfig.ts new file mode 100644 index 00000000..e1a9333d --- /dev/null +++ b/src/overridableConfig.ts @@ -0,0 +1,182 @@ +import { ProcessEnvEx } from '../types/process-env.js'; +import { removeClaudeMcpBundleUserConfigTemplates } from './config.js'; +import { isToolGroupName, isToolName, toolGroups, ToolName } from './tools/toolName.js'; + +const overridableVariables = [ + 'INCLUDE_TOOLS', + 'EXCLUDE_TOOLS', + 'INCLUDE_PROJECT_IDS', + 'INCLUDE_DATASOURCE_IDS', + 'INCLUDE_WORKBOOK_IDS', + 'INCLUDE_TAGS', + 'MAX_RESULT_LIMIT', + 'MAX_RESULT_LIMITS', + 'DISABLE_QUERY_DATASOURCE_VALIDATION_REQUESTS', + 'DISABLE_METADATA_API_REQUESTS', +] as const satisfies ReadonlyArray; + +type OverridableVariable = (typeof overridableVariables)[number]; +function isOverridableVariable(variable: unknown): variable is OverridableVariable { + return overridableVariables.some((v) => v === variable); +} + +function filterEnvVarsToOverridable( + environmentVariables: Record, +): Record { + return Object.fromEntries( + Object.entries(environmentVariables).filter(([key]) => isOverridableVariable(key)), + ) as Record; +} + +export type BoundedContext = { + projectIds: Set | null; + datasourceIds: Set | null; + workbookIds: Set | null; + tags: Set | null; +}; + +export class OverridableConfig { + private maxResultLimit: number | null; + private maxResultLimits: Map | null; + + includeTools: Array; + excludeTools: Array; + + disableQueryDatasourceValidationRequests: boolean; + disableMetadataApiRequests: boolean; + + boundedContext: BoundedContext; + + getMaxResultLimit(toolName: ToolName): number | null { + return this.maxResultLimits?.get(toolName) ?? this.maxResultLimit; + } + + constructor(overrides: Record | undefined) { + const cleansedVars = removeClaudeMcpBundleUserConfigTemplates({ + ...process.env, + ...(overrides ? filterEnvVarsToOverridable(overrides) : {}), + }); + + const { + INCLUDE_TOOLS: includeTools, + EXCLUDE_TOOLS: excludeTools, + MAX_RESULT_LIMIT: maxResultLimit, + MAX_RESULT_LIMITS: maxResultLimits, + DISABLE_QUERY_DATASOURCE_VALIDATION_REQUESTS: disableQueryDatasourceValidationRequests, + DISABLE_METADATA_API_REQUESTS: disableMetadataApiRequests, + INCLUDE_PROJECT_IDS: includeProjectIds, + INCLUDE_DATASOURCE_IDS: includeDatasourceIds, + INCLUDE_WORKBOOK_IDS: includeWorkbookIds, + INCLUDE_TAGS: includeTags, + } = cleansedVars; + + this.disableQueryDatasourceValidationRequests = + disableQueryDatasourceValidationRequests === 'true'; + this.disableMetadataApiRequests = disableMetadataApiRequests === 'true'; + + this.boundedContext = { + projectIds: createSetFromCommaSeparatedString(includeProjectIds), + datasourceIds: createSetFromCommaSeparatedString(includeDatasourceIds), + workbookIds: createSetFromCommaSeparatedString(includeWorkbookIds), + tags: createSetFromCommaSeparatedString(includeTags), + }; + + if (this.boundedContext.projectIds?.size === 0) { + throw new Error( + 'When set, the environment variable INCLUDE_PROJECT_IDS must have at least one value', + ); + } + + if (this.boundedContext.datasourceIds?.size === 0) { + throw new Error( + 'When set, the environment variable INCLUDE_DATASOURCE_IDS must have at least one value', + ); + } + + if (this.boundedContext.workbookIds?.size === 0) { + throw new Error( + 'When set, the environment variable INCLUDE_WORKBOOK_IDS must have at least one value', + ); + } + + if (this.boundedContext.tags?.size === 0) { + throw new Error( + 'When set, the environment variable INCLUDE_TAGS must have at least one value', + ); + } + + const maxResultLimitNumber = maxResultLimit ? parseInt(maxResultLimit) : NaN; + this.maxResultLimit = + isNaN(maxResultLimitNumber) || maxResultLimitNumber <= 0 ? null : maxResultLimitNumber; + + this.maxResultLimits = maxResultLimits ? getMaxResultLimits(maxResultLimits) : null; + + this.includeTools = includeTools + ? includeTools.split(',').flatMap((s) => { + const v = s.trim(); + return isToolName(v) ? v : isToolGroupName(v) ? toolGroups[v] : []; + }) + : []; + + this.excludeTools = excludeTools + ? excludeTools.split(',').flatMap((s) => { + const v = s.trim(); + return isToolName(v) ? v : isToolGroupName(v) ? toolGroups[v] : []; + }) + : []; + + if (this.includeTools.length > 0 && this.excludeTools.length > 0) { + throw new Error('Cannot include and exclude tools simultaneously'); + } + } +} + +// Creates a set from a comma-separated string of values. +// Returns null if the value is undefined. +function createSetFromCommaSeparatedString(value: string | undefined): Set | null { + if (value === undefined) { + return null; + } + + return new Set( + value + .trim() + .split(',') + .map((id) => id.trim()) + .filter(Boolean), + ); +} + +function getMaxResultLimits(maxResultLimits: string): Map { + const map = new Map(); + if (!maxResultLimits) { + return map; + } + + maxResultLimits.split(',').forEach((curr) => { + const [toolName, maxResultLimit] = curr.split(':'); + const maxResultLimitNumber = maxResultLimit ? parseInt(maxResultLimit) : NaN; + const actualLimit = + isNaN(maxResultLimitNumber) || maxResultLimitNumber <= 0 ? null : maxResultLimitNumber; + if (isToolName(toolName)) { + map.set(toolName, actualLimit); + } else if (isToolGroupName(toolName)) { + toolGroups[toolName].forEach((toolName) => { + if (!map.has(toolName)) { + // Tool names take precedence over group names + map.set(toolName, actualLimit); + } + }); + } + }); + + return map; +} + +export const getOverridableConfig = ( + overrides: Record | undefined, +): OverridableConfig => new OverridableConfig(overrides); + +export const exportedForTesting = { + OverridableConfig: OverridableConfig, +}; diff --git a/src/restApiInstance.ts b/src/restApiInstance.ts index 50ba844d..a1f519a6 100644 --- a/src/restApiInstance.ts +++ b/src/restApiInstance.ts @@ -28,31 +28,49 @@ type JwtScopes = | 'tableau:metric_subscriptions:read' | 'tableau:insights:read' | 'tableau:views:download' - | 'tableau:insight_brief:create'; + | 'tableau:insight_brief:create' + | 'tableau:mcp_site_settings:read'; + +export type RestApiArgs = { + config: Config; + server: Server; + signal: AbortSignal; + authInfo?: TableauAuthInfo; +} & ( + | { + requestId: RequestId; + disableLogging?: false; + } + | { + disableLogging: true; + } +); const getNewRestApiInstanceAsync = async ( - config: Config, - requestId: RequestId, - server: Server, - jwtScopes: Set, - signal: AbortSignal, - authInfo?: TableauAuthInfo, + args: RestApiArgs & { + jwtScopes: Set; + }, ): Promise => { - signal.addEventListener( - 'abort', - () => { - log.info( - server, - { - type: 'request-cancelled', - requestId, - reason: signal.reason, - }, - { logger: server.name, requestId }, - ); - }, - { once: true }, - ); + const { config, server, jwtScopes, signal, authInfo, disableLogging } = args; + + if (!disableLogging) { + const { requestId } = args; + signal.addEventListener( + 'abort', + () => { + log.info( + server, + { + type: 'request-cancelled', + requestId, + reason: signal.reason, + }, + { logger: server.name, requestId }, + ); + }, + { once: true }, + ); + } const tableauServer = config.server || authInfo?.server; invariant(tableauServer, 'Tableau server could not be determined'); @@ -60,14 +78,18 @@ const getNewRestApiInstanceAsync = async ( const restApi = new RestApi(tableauServer, { maxRequestTimeoutMs: config.maxRequestTimeoutMs, signal, - requestInterceptor: [ - getRequestInterceptor(server, requestId), - getRequestErrorInterceptor(server, requestId), - ], - responseInterceptor: [ - getResponseInterceptor(server, requestId), - getResponseErrorInterceptor(server, requestId), - ], + requestInterceptor: disableLogging + ? undefined + : [ + getRequestInterceptor(server, args.requestId), + getRequestErrorInterceptor(server, args.requestId), + ], + responseInterceptor: disableLogging + ? undefined + : [ + getResponseInterceptor(server, args.requestId), + getResponseErrorInterceptor(server, args.requestId), + ], }); if (config.auth === 'pat') { @@ -112,31 +134,18 @@ const getNewRestApiInstanceAsync = async ( return restApi; }; -export const useRestApi = async ({ - config, - requestId, - server, - callback, - jwtScopes, - signal, - authInfo, -}: { - config: Config; - requestId: RequestId; - server: Server; - jwtScopes: Array; - signal: AbortSignal; - callback: (restApi: RestApi) => Promise; - authInfo?: TableauAuthInfo; -}): Promise => { - const restApi = await getNewRestApiInstanceAsync( - config, - requestId, - server, - new Set(jwtScopes), - signal, - authInfo, - ); +export const useRestApi = async ( + args: RestApiArgs & { + jwtScopes: Array; + callback: (restApi: RestApi) => Promise; + }, +): Promise => { + const { callback, ...remaining } = args; + const { config } = remaining; + const restApi = await getNewRestApiInstanceAsync({ + ...remaining, + jwtScopes: new Set(args.jwtScopes), + }); try { return await callback(restApi); } finally { diff --git a/src/scripts/createClaudeMcpBundleManifest.ts b/src/scripts/createClaudeMcpBundleManifest.ts index 5d13e12c..0c3a18b7 100644 --- a/src/scripts/createClaudeMcpBundleManifest.ts +++ b/src/scripts/createClaudeMcpBundleManifest.ts @@ -365,6 +365,22 @@ const envVars = { required: false, sensitive: false, }, + MCP_SITE_SETTINGS_CHECK_INTERVAL_IN_MINUTES: { + includeInUserConfig: false, + type: 'number', + title: 'MCP Site Settings Check Interval in Minutes', + description: 'The interval in minutes to check the MCP site settings.', + required: false, + sensitive: false, + }, + ENABLE_MCP_SITE_SETTINGS: { + includeInUserConfig: false, + type: 'boolean', + title: 'Enable MCP Site Settings', + description: 'Enable MCP site settings.', + required: false, + sensitive: false, + }, DANGEROUSLY_DISABLE_OAUTH: { includeInUserConfig: false, type: 'boolean', diff --git a/src/sdks/tableau/restApi.ts b/src/sdks/tableau/restApi.ts index d83bd863..69ab6876 100644 --- a/src/sdks/tableau/restApi.ts +++ b/src/sdks/tableau/restApi.ts @@ -21,6 +21,7 @@ import ViewsMethods from './methods/viewsMethods.js'; import VizqlDataServiceMethods from './methods/vizqlDataServiceMethods.js'; import WorkbooksMethods from './methods/workbooksMethods.js'; import { Credentials } from './types/credentials.js'; +import { McpSiteSettings } from './types/mcpSiteSettings.js'; /** * Interface for the Tableau REST APIs @@ -186,6 +187,19 @@ export class RestApi { return this._serverMethods; } + get siteMethods(): { getMcpSettings: () => Promise } { + return { + getMcpSettings: async (): Promise => { + // When the "Get MCP Site Settings" REST API is available: + // 1. Remove this comment. + // 2. Default enableMcpSiteSettings to enabled. + // 3. Add documentation for ENABLE_MCP_SITE_SETTINGS. + // 4. Add documentation for MCP_SITE_SETTINGS_CHECK_INTERVAL_IN_MINUTES. + return {}; + }, + }; + } + get vizqlDataServiceMethods(): VizqlDataServiceMethods { if (!this._vizqlDataServiceMethods) { const baseUrl = `${this._host}/api/v1/vizql-data-service`; diff --git a/src/sdks/tableau/types/mcpSiteSettings.ts b/src/sdks/tableau/types/mcpSiteSettings.ts new file mode 100644 index 00000000..f0db5179 --- /dev/null +++ b/src/sdks/tableau/types/mcpSiteSettings.ts @@ -0,0 +1,4 @@ +import { z } from 'zod'; + +export const mcpSiteSettingsSchema = z.record(z.string(), z.string()); +export type McpSiteSettings = z.infer; diff --git a/src/server.ts b/src/server.ts index e92528a4..fe670072 100644 --- a/src/server.ts +++ b/src/server.ts @@ -2,12 +2,12 @@ import { McpServer } from '@modelcontextprotocol/sdk/server/mcp.js'; import { InitializeRequest, SetLevelRequestSchema } from '@modelcontextprotocol/sdk/types.js'; import pkg from '../package.json'; -import { getConfig } from './config.js'; import { setLogLevel } from './logging/log.js'; import { TableauAuthInfo } from './server/oauth/schemas.js'; import { Tool } from './tools/tool.js'; import { toolNames } from './tools/toolName.js'; import { toolFactories } from './tools/tools.js'; +import { getConfigWithOverrides } from './utils/mcpSiteSettings'; import { Provider } from './utils/provider.js'; export const serverName = 'tableau-mcp'; @@ -60,7 +60,7 @@ export class Server extends McpServer { paramsSchema, annotations, callback, - } of this._getToolsToRegister(authInfo)) { + } of await this._getToolsToRegister(authInfo)) { this.registerTool( name, { @@ -80,8 +80,16 @@ export class Server extends McpServer { }); }; - private _getToolsToRegister = (authInfo?: TableauAuthInfo): Array> => { - const { includeTools, excludeTools } = getConfig(); + private _getToolsToRegister = async (authInfo?: TableauAuthInfo): Promise>> => { + const config = await getConfigWithOverrides({ + restApiArgs: { + server: this, + authInfo, + disableLogging: true, // MCP server is not connected yet so we can't send logging notifications + }, + }); + + const { includeTools, excludeTools } = config; const tools = toolFactories.map((toolFactory) => toolFactory(this, authInfo)); const toolsToRegister = tools.filter((tool) => { diff --git a/src/tools/contentExploration/searchContent.ts b/src/tools/contentExploration/searchContent.ts index a99444f4..5a2e9a70 100644 --- a/src/tools/contentExploration/searchContent.ts +++ b/src/tools/contentExploration/searchContent.ts @@ -11,6 +11,7 @@ import { import { Server } from '../../server.js'; import { getTableauAuthInfo } from '../../server/oauth/getTableauAuthInfo.js'; import { getSiteLuidFromAccessToken } from '../../utils/getSiteLuidFromAccessToken.js'; +import { getConfigWithOverrides } from '../../utils/mcpSiteSettings.js'; import { Tool } from '../tool.js'; import { buildFilterString, @@ -67,6 +68,18 @@ This tool searches across all supported content types for objects relevant to th { requestId, sessionId, authInfo, signal }, ): Promise => { const config = getConfig(); + const restApiArgs = { + config, + requestId, + server, + signal, + authInfo: getTableauAuthInfo(authInfo), + }; + + const configWithOverrides = await getConfigWithOverrides({ + restApiArgs, + }); + const orderByString = orderBy ? buildOrderByString(orderBy) : undefined; const filterString = filter ? buildFilterString(filter) : undefined; return await searchContentTool.logAndExecute>({ @@ -77,14 +90,13 @@ This tool searches across all supported content types for objects relevant to th callback: async () => { return new Ok( await useRestApi({ - config, - requestId, - server, + ...restApiArgs, jwtScopes: ['tableau:content:read'], - signal, - authInfo: getTableauAuthInfo(authInfo), callback: async (restApi) => { - const maxResultLimit = config.getMaxResultLimit(searchContentTool.name); + const maxResultLimit = configWithOverrides.getMaxResultLimit( + searchContentTool.name, + ); + const response = await restApi.contentExplorationMethods.searchContent({ terms, page: 0, @@ -97,8 +109,12 @@ This tool searches across all supported content types for objects relevant to th }), ); }, - constrainSuccessResult: (items) => - constrainSearchContent({ items, boundedContext: config.boundedContext }), + constrainSuccessResult: async (items) => { + return constrainSearchContent({ + items, + boundedContext: configWithOverrides.boundedContext, + }); + }, productTelemetryBase: { endpoint: config.productTelemetryEndpoint, siteLuid: getSiteLuidFromAccessToken(getTableauAuthInfo(authInfo)?.accessToken), diff --git a/src/tools/contentExploration/searchContentUtils.ts b/src/tools/contentExploration/searchContentUtils.ts index e8919cc7..04bcbae9 100644 --- a/src/tools/contentExploration/searchContentUtils.ts +++ b/src/tools/contentExploration/searchContentUtils.ts @@ -1,4 +1,4 @@ -import { BoundedContext } from '../../config.js'; +import { BoundedContext } from '../../overridableConfig.js'; import { OrderBy, SearchContentFilter, diff --git a/src/tools/getDatasourceMetadata/getDatasourceMetadata.ts b/src/tools/getDatasourceMetadata/getDatasourceMetadata.ts index 08fdc5ba..4cb8ee15 100644 --- a/src/tools/getDatasourceMetadata/getDatasourceMetadata.ts +++ b/src/tools/getDatasourceMetadata/getDatasourceMetadata.ts @@ -8,6 +8,7 @@ import { GraphQLResponse } from '../../sdks/tableau/apis/metadataApi.js'; import { Server } from '../../server.js'; import { getTableauAuthInfo } from '../../server/oauth/getTableauAuthInfo.js'; import { getSiteLuidFromAccessToken } from '../../utils/getSiteLuidFromAccessToken.js'; +import { getConfigWithOverrides } from '../../utils/mcpSiteSettings.js'; import { getVizqlDataServiceDisabledError } from '../getVizqlDataServiceDisabledError.js'; import { resourceAccessChecker } from '../resourceAccessChecker.js'; import { Tool } from '../tool.js'; @@ -127,9 +128,21 @@ export const getGetDatasourceMetadataTool = (server: Server): Tool { + const restApiArgs = { + config, + requestId, + server, + signal, + authInfo: getTableauAuthInfo(authInfo), + }; + + const configWithOverrides = await getConfigWithOverrides({ + restApiArgs, + }); + const isDatasourceAllowedResult = await resourceAccessChecker.isDatasourceAllowed({ datasourceLuid, - restApiArgs: { config, requestId, server, signal }, + restApiArgs, }); if (!isDatasourceAllowedResult.allowed) { @@ -140,12 +153,8 @@ export const getGetDatasourceMetadataTool = (server: Server): Tool { // Fetching metadata from VizQL Data Service API. const readMetadataResult = await restApi.vizqlDataServiceMethods.readMetadata({ @@ -158,7 +167,7 @@ export const getGetDatasourceMetadataTool = (server: Server): Tool => { const config = getConfig(); + const restApiArgs = { + config, + requestId, + server, + signal, + authInfo: getTableauAuthInfo(authInfo), + }; + + const configWithOverrides = await getConfigWithOverrides({ + restApiArgs, + }); + const validatedFilter = filter ? parseAndValidateDatasourcesFilterString(filter) : undefined; return await listDatasourcesTool.logAndExecute({ requestId, @@ -92,14 +106,13 @@ export const getListDatasourcesTool = (server: Server): Tool { const datasources = await useRestApi({ - config, - requestId, - server, + ...restApiArgs, jwtScopes: ['tableau:content:read'], - signal, - authInfo: getTableauAuthInfo(authInfo), callback: async (restApi) => { - const maxResultLimit = config.getMaxResultLimit(listDatasourcesTool.name); + const maxResultLimit = configWithOverrides.getMaxResultLimit( + listDatasourcesTool.name, + ); + const datasources = await paginate({ pageConfig: { pageSize, @@ -126,8 +139,12 @@ export const getListDatasourcesTool = (server: Server): Tool - constrainDatasources({ datasources, boundedContext: config.boundedContext }), + constrainSuccessResult: async (datasources) => { + return constrainDatasources({ + datasources, + boundedContext: configWithOverrides.boundedContext, + }); + }, productTelemetryBase: { endpoint: config.productTelemetryEndpoint, siteLuid: getSiteLuidFromAccessToken(getTableauAuthInfo(authInfo)?.accessToken), diff --git a/src/tools/pulse/constrainPulseDefinitions.ts b/src/tools/pulse/constrainPulseDefinitions.ts index 300de13a..aa03dd21 100644 --- a/src/tools/pulse/constrainPulseDefinitions.ts +++ b/src/tools/pulse/constrainPulseDefinitions.ts @@ -1,4 +1,4 @@ -import { BoundedContext } from '../../config.js'; +import { BoundedContext } from '../../overridableConfig.js'; import { PulseMetricDefinition } from '../../sdks/tableau/types/pulse.js'; import { ConstrainedResult } from '../tool.js'; diff --git a/src/tools/pulse/constrainPulseMetrics.ts b/src/tools/pulse/constrainPulseMetrics.ts index ad0abf61..50c0632d 100644 --- a/src/tools/pulse/constrainPulseMetrics.ts +++ b/src/tools/pulse/constrainPulseMetrics.ts @@ -1,4 +1,4 @@ -import { BoundedContext } from '../../config.js'; +import { BoundedContext } from '../../overridableConfig.js'; import { PulseMetric } from '../../sdks/tableau/types/pulse.js'; import { ConstrainedResult } from '../tool.js'; diff --git a/src/tools/pulse/generateInsightBrief/generatePulseInsightBriefTool.ts b/src/tools/pulse/generateInsightBrief/generatePulseInsightBriefTool.ts index 1059bfc9..224b3e45 100644 --- a/src/tools/pulse/generateInsightBrief/generatePulseInsightBriefTool.ts +++ b/src/tools/pulse/generateInsightBrief/generatePulseInsightBriefTool.ts @@ -11,6 +11,7 @@ import { import { Server } from '../../../server.js'; import { getTableauAuthInfo } from '../../../server/oauth/getTableauAuthInfo.js'; import { getSiteLuidFromAccessToken } from '../../../utils/getSiteLuidFromAccessToken.js'; +import { getConfigWithOverrides } from '../../../utils/mcpSiteSettings.js'; import { Tool } from '../../tool.js'; import { getPulseDisabledError } from '../getPulseDisabledError.js'; @@ -198,6 +199,15 @@ An insight brief is an AI-generated response to questions about Pulse metrics. I { requestId, sessionId, authInfo, signal }, ): Promise => { const config = getConfig(); + const restApiArgs = { + config, + requestId, + server, + signal, + authInfo: getTableauAuthInfo(authInfo), + }; + const configWithOverrides = await getConfigWithOverrides({ restApiArgs }); + return await generatePulseInsightBriefTool.logAndExecute< PulseInsightBriefResponse, GeneratePulseInsightBriefError @@ -208,7 +218,7 @@ An insight brief is an AI-generated response to questions about Pulse metrics. I args: { briefRequest }, callback: async () => { // Filter out metrics that are not in the allowed datasource set - const { datasourceIds } = config.boundedContext; + const { datasourceIds } = configWithOverrides.boundedContext; if (datasourceIds) { for (const message of briefRequest.messages) { if (message.metric_group_context) { @@ -232,12 +242,8 @@ An insight brief is an AI-generated response to questions about Pulse metrics. I } const result = await useRestApi({ - config, - requestId, - server, + ...restApiArgs, jwtScopes: ['tableau:insight_brief:create'], - signal, - authInfo: getTableauAuthInfo(authInfo), callback: async (restApi) => await restApi.pulseMethods.generatePulseInsightBrief(briefRequest), }); diff --git a/src/tools/pulse/generateMetricValueInsightBundle/generatePulseMetricValueInsightBundleTool.ts b/src/tools/pulse/generateMetricValueInsightBundle/generatePulseMetricValueInsightBundleTool.ts index be057fcf..b377a65e 100644 --- a/src/tools/pulse/generateMetricValueInsightBundle/generatePulseMetricValueInsightBundleTool.ts +++ b/src/tools/pulse/generateMetricValueInsightBundle/generatePulseMetricValueInsightBundleTool.ts @@ -13,6 +13,7 @@ import { import { Server } from '../../../server.js'; import { getTableauAuthInfo } from '../../../server/oauth/getTableauAuthInfo.js'; import { getSiteLuidFromAccessToken } from '../../../utils/getSiteLuidFromAccessToken.js'; +import { getConfigWithOverrides } from '../../../utils/mcpSiteSettings.js'; import { Tool } from '../../tool.js'; import { getPulseDisabledError } from '../getPulseDisabledError.js'; @@ -165,7 +166,16 @@ Generate an insight bundle for the current aggregated value for Pulse Metric usi authInfo, args: { bundleRequest, bundleType }, callback: async () => { - const { datasourceIds } = config.boundedContext; + const restApiArgs = { + config, + requestId, + server, + signal, + authInfo: getTableauAuthInfo(authInfo), + }; + const configWithOverrides = await getConfigWithOverrides({ restApiArgs }); + + const { datasourceIds } = configWithOverrides.boundedContext; if (datasourceIds) { const datasourceLuid = bundleRequest.bundle_request.input.metric.definition.datasource.id; @@ -183,12 +193,8 @@ Generate an insight bundle for the current aggregated value for Pulse Metric usi } const result = await useRestApi({ - config, - requestId, - server, + ...restApiArgs, jwtScopes: ['tableau:insights:read'], - signal, - authInfo: getTableauAuthInfo(authInfo), callback: async (restApi) => await restApi.pulseMethods.generatePulseMetricValueInsightBundle( bundleRequest, diff --git a/src/tools/pulse/listAllMetricDefinitions/listAllPulseMetricDefinitions.ts b/src/tools/pulse/listAllMetricDefinitions/listAllPulseMetricDefinitions.ts index 761d8d44..b2298d71 100644 --- a/src/tools/pulse/listAllMetricDefinitions/listAllPulseMetricDefinitions.ts +++ b/src/tools/pulse/listAllMetricDefinitions/listAllPulseMetricDefinitions.ts @@ -11,6 +11,7 @@ import { import { Server } from '../../../server.js'; import { getTableauAuthInfo } from '../../../server/oauth/getTableauAuthInfo.js'; import { getSiteLuidFromAccessToken } from '../../../utils/getSiteLuidFromAccessToken.js'; +import { getConfigWithOverrides } from '../../../utils/mcpSiteSettings.js'; import { pulsePaginate } from '../../../utils/paginate.js'; import { Tool } from '../../tool.js'; import { constrainPulseDefinitions } from '../constrainPulseDefinitions.js'; @@ -63,6 +64,18 @@ Retrieves a list of all published Pulse Metric Definitions using the Tableau RES { requestId, sessionId, authInfo, signal }, ): Promise => { const config = getConfig(); + const restApiArgs = { + config, + requestId, + server, + signal, + authInfo: getTableauAuthInfo(authInfo), + }; + + const configWithOverrides = await getConfigWithOverrides({ + restApiArgs, + }); + return await listAllPulseMetricDefinitionsTool.logAndExecute({ requestId, sessionId, @@ -70,16 +83,13 @@ Retrieves a list of all published Pulse Metric Definitions using the Tableau RES args: { view, limit, pageSize }, callback: async () => { return await useRestApi({ - config, - requestId, - server, + ...restApiArgs, jwtScopes: ['tableau:insight_definitions_metrics:read'], - signal, - authInfo: getTableauAuthInfo(authInfo), callback: async (restApi) => { - const maxResultLimit = config.getMaxResultLimit( + const maxResultLimit = configWithOverrides.getMaxResultLimit( listAllPulseMetricDefinitionsTool.name, ); + const definitions = await pulsePaginate({ config: { limit: maxResultLimit @@ -108,8 +118,12 @@ Retrieves a list of all published Pulse Metric Definitions using the Tableau RES }, }); }, - constrainSuccessResult: (definitions: Array) => - constrainPulseDefinitions({ definitions, boundedContext: config.boundedContext }), + constrainSuccessResult: async (definitions: Array) => { + return constrainPulseDefinitions({ + definitions, + boundedContext: configWithOverrides.boundedContext, + }); + }, getErrorText: getPulseDisabledError, productTelemetryBase: { endpoint: config.productTelemetryEndpoint, diff --git a/src/tools/pulse/listMetricDefinitionsFromDefinitionIds/listPulseMetricDefinitionsFromDefinitionIds.ts b/src/tools/pulse/listMetricDefinitionsFromDefinitionIds/listPulseMetricDefinitionsFromDefinitionIds.ts index 38aa4fd9..b8573ec6 100644 --- a/src/tools/pulse/listMetricDefinitionsFromDefinitionIds/listPulseMetricDefinitionsFromDefinitionIds.ts +++ b/src/tools/pulse/listMetricDefinitionsFromDefinitionIds/listPulseMetricDefinitionsFromDefinitionIds.ts @@ -7,6 +7,7 @@ import { pulseMetricDefinitionViewEnum } from '../../../sdks/tableau/types/pulse import { Server } from '../../../server.js'; import { getTableauAuthInfo } from '../../../server/oauth/getTableauAuthInfo.js'; import { getSiteLuidFromAccessToken } from '../../../utils/getSiteLuidFromAccessToken.js'; +import { getConfigWithOverrides } from '../../../utils/mcpSiteSettings.js'; import { Tool } from '../../tool.js'; import { constrainPulseDefinitions } from '../constrainPulseDefinitions.js'; import { getPulseDisabledError } from '../getPulseDisabledError.js'; @@ -63,6 +64,14 @@ Retrieves a list of specific Pulse Metric Definitions using the Tableau REST API { requestId, sessionId, authInfo, signal }, ): Promise => { const config = getConfig(); + const restApiArgs = { + config, + requestId, + server, + signal, + authInfo: getTableauAuthInfo(authInfo), + }; + return await listPulseMetricDefinitionsFromDefinitionIdsTool.logAndExecute({ requestId, sessionId, @@ -70,12 +79,8 @@ Retrieves a list of specific Pulse Metric Definitions using the Tableau REST API args: { metricDefinitionIds, view }, callback: async () => { return await useRestApi({ - config, - requestId, - server, + ...restApiArgs, jwtScopes: ['tableau:insight_definitions_metrics:read'], - signal, - authInfo: getTableauAuthInfo(authInfo), callback: async (restApi) => { return await restApi.pulseMethods.listPulseMetricDefinitionsFromMetricDefinitionIds( metricDefinitionIds, @@ -84,8 +89,16 @@ Retrieves a list of specific Pulse Metric Definitions using the Tableau REST API }, }); }, - constrainSuccessResult: (definitions) => - constrainPulseDefinitions({ definitions, boundedContext: config.boundedContext }), + constrainSuccessResult: async (definitions) => { + const configWithOverrides = await getConfigWithOverrides({ + restApiArgs, + }); + + return constrainPulseDefinitions({ + definitions, + boundedContext: configWithOverrides.boundedContext, + }); + }, getErrorText: getPulseDisabledError, productTelemetryBase: { endpoint: config.productTelemetryEndpoint, diff --git a/src/tools/pulse/listMetricSubscriptions/listPulseMetricSubscriptions.ts b/src/tools/pulse/listMetricSubscriptions/listPulseMetricSubscriptions.ts index 5c330115..1ec8c1ed 100644 --- a/src/tools/pulse/listMetricSubscriptions/listPulseMetricSubscriptions.ts +++ b/src/tools/pulse/listMetricSubscriptions/listPulseMetricSubscriptions.ts @@ -1,13 +1,14 @@ import { CallToolResult } from '@modelcontextprotocol/sdk/types.js'; -import { BoundedContext, getConfig } from '../../../config.js'; -import { useRestApi } from '../../../restApiInstance.js'; +import { getConfig } from '../../../config.js'; +import { BoundedContext } from '../../../overridableConfig.js'; +import { RestApiArgs, useRestApi } from '../../../restApiInstance.js'; import { PulseMetricSubscription } from '../../../sdks/tableau/types/pulse.js'; import { Server } from '../../../server.js'; import { getTableauAuthInfo } from '../../../server/oauth/getTableauAuthInfo.js'; import { getExceptionMessage } from '../../../utils/getExceptionMessage.js'; import { getSiteLuidFromAccessToken } from '../../../utils/getSiteLuidFromAccessToken.js'; -import { RestApiArgs } from '../../resourceAccessChecker.js'; +import { getConfigWithOverrides } from '../../../utils/mcpSiteSettings.js'; import { ConstrainedResult, Tool } from '../../tool.js'; import { getPulseDisabledError } from '../getPulseDisabledError.js'; @@ -37,6 +38,14 @@ Retrieves a list of published Pulse Metric Subscriptions for the current user us }, callback: async (_, { requestId, sessionId, authInfo, signal }): Promise => { const config = getConfig(); + const restApiArgs = { + config, + requestId, + server, + signal, + authInfo: getTableauAuthInfo(authInfo), + }; + return await listPulseMetricSubscriptionsTool.logAndExecute({ requestId, sessionId, @@ -44,22 +53,22 @@ Retrieves a list of published Pulse Metric Subscriptions for the current user us args: {}, callback: async () => { return await useRestApi({ - config, - requestId, - server, + ...restApiArgs, jwtScopes: ['tableau:metric_subscriptions:read'], - signal, - authInfo: getTableauAuthInfo(authInfo), callback: async (restApi) => { return await restApi.pulseMethods.listPulseMetricSubscriptionsForCurrentUser(); }, }); }, constrainSuccessResult: async (subscriptions) => { + const configWithOverrides = await getConfigWithOverrides({ + restApiArgs, + }); + return await constrainPulseMetricSubscriptions({ subscriptions, - boundedContext: config.boundedContext, - restApiArgs: { config, requestId, server, signal }, + boundedContext: configWithOverrides.boundedContext, + restApiArgs, }); }, getErrorText: getPulseDisabledError, @@ -103,14 +112,10 @@ export async function constrainPulseMetricSubscriptions({ }; } - const { config, requestId, server, signal } = restApiArgs; try { const metricsResult = await useRestApi({ - config, - requestId, - server, + ...restApiArgs, jwtScopes: ['tableau:insight_metrics:read'], - signal, callback: async (restApi) => { return await restApi.pulseMethods.listPulseMetricsFromMetricIds( subscriptions.map((subscription) => subscription.metric_id), diff --git a/src/tools/pulse/listMetricsFromMetricDefinitionId/listPulseMetricsFromMetricDefinitionId.ts b/src/tools/pulse/listMetricsFromMetricDefinitionId/listPulseMetricsFromMetricDefinitionId.ts index 45a1fae0..35205d1d 100644 --- a/src/tools/pulse/listMetricsFromMetricDefinitionId/listPulseMetricsFromMetricDefinitionId.ts +++ b/src/tools/pulse/listMetricsFromMetricDefinitionId/listPulseMetricsFromMetricDefinitionId.ts @@ -8,6 +8,7 @@ import { PulseMetric } from '../../../sdks/tableau/types/pulse.js'; import { Server } from '../../../server.js'; import { getTableauAuthInfo } from '../../../server/oauth/getTableauAuthInfo.js'; import { getSiteLuidFromAccessToken } from '../../../utils/getSiteLuidFromAccessToken.js'; +import { getConfigWithOverrides } from '../../../utils/mcpSiteSettings.js'; import { Tool } from '../../tool.js'; import { constrainPulseMetrics } from '../constrainPulseMetrics.js'; import { getPulseDisabledError } from '../getPulseDisabledError.js'; @@ -42,6 +43,14 @@ Retrieves a list of published Pulse Metrics from a Pulse Metric Definition using { requestId, sessionId, authInfo, signal }, ): Promise => { const config = getConfig(); + const restApiArgs = { + config, + requestId, + server, + signal, + authInfo: getTableauAuthInfo(authInfo), + }; + return await listPulseMetricsFromMetricDefinitionIdTool.logAndExecute< Array, PulseDisabledError @@ -52,12 +61,8 @@ Retrieves a list of published Pulse Metrics from a Pulse Metric Definition using args: { pulseMetricDefinitionID }, callback: async () => { return await useRestApi({ - config, - requestId, - server, + ...restApiArgs, jwtScopes: ['tableau:insight_definitions_metrics:read'], - signal, - authInfo: getTableauAuthInfo(authInfo), callback: async (restApi) => { return await restApi.pulseMethods.listPulseMetricsFromMetricDefinitionId( pulseMetricDefinitionID, @@ -65,8 +70,16 @@ Retrieves a list of published Pulse Metrics from a Pulse Metric Definition using }, }); }, - constrainSuccessResult: (metrics) => - constrainPulseMetrics({ metrics, boundedContext: config.boundedContext }), + constrainSuccessResult: async (metrics) => { + const configWithOverrides = await getConfigWithOverrides({ + restApiArgs, + }); + + return constrainPulseMetrics({ + metrics, + boundedContext: configWithOverrides.boundedContext, + }); + }, getErrorText: getPulseDisabledError, productTelemetryBase: { endpoint: config.productTelemetryEndpoint, diff --git a/src/tools/pulse/listMetricsFromMetricIds/listPulseMetricsFromMetricIds.ts b/src/tools/pulse/listMetricsFromMetricIds/listPulseMetricsFromMetricIds.ts index ffc4f278..75ffd00e 100644 --- a/src/tools/pulse/listMetricsFromMetricIds/listPulseMetricsFromMetricIds.ts +++ b/src/tools/pulse/listMetricsFromMetricIds/listPulseMetricsFromMetricIds.ts @@ -6,6 +6,7 @@ import { useRestApi } from '../../../restApiInstance.js'; import { Server } from '../../../server.js'; import { getTableauAuthInfo } from '../../../server/oauth/getTableauAuthInfo.js'; import { getSiteLuidFromAccessToken } from '../../../utils/getSiteLuidFromAccessToken.js'; +import { getConfigWithOverrides } from '../../../utils/mcpSiteSettings.js'; import { Tool } from '../../tool.js'; import { constrainPulseMetrics } from '../constrainPulseMetrics.js'; import { getPulseDisabledError } from '../getPulseDisabledError.js'; @@ -44,6 +45,14 @@ Retrieves a list of published Pulse Metrics from a list of metric IDs using the { requestId, sessionId, authInfo, signal }, ): Promise => { const config = getConfig(); + const restApiArgs = { + config, + requestId, + server, + signal, + authInfo: getTableauAuthInfo(authInfo), + }; + return await listPulseMetricsFromMetricIdsTool.logAndExecute({ requestId, sessionId, @@ -51,19 +60,23 @@ Retrieves a list of published Pulse Metrics from a list of metric IDs using the args: { metricIds }, callback: async () => { return await useRestApi({ - config, - requestId, - server, + ...restApiArgs, jwtScopes: ['tableau:insight_metrics:read'], - signal, - authInfo: getTableauAuthInfo(authInfo), callback: async (restApi) => { return await restApi.pulseMethods.listPulseMetricsFromMetricIds(metricIds); }, }); }, - constrainSuccessResult: (metrics) => - constrainPulseMetrics({ metrics, boundedContext: config.boundedContext }), + constrainSuccessResult: async (metrics) => { + const configWithOverrides = await getConfigWithOverrides({ + restApiArgs, + }); + + return constrainPulseMetrics({ + metrics, + boundedContext: configWithOverrides.boundedContext, + }); + }, getErrorText: getPulseDisabledError, productTelemetryBase: { endpoint: config.productTelemetryEndpoint, diff --git a/src/tools/queryDatasource/queryDatasource.ts b/src/tools/queryDatasource/queryDatasource.ts index 89ac82db..fdff9326 100644 --- a/src/tools/queryDatasource/queryDatasource.ts +++ b/src/tools/queryDatasource/queryDatasource.ts @@ -16,6 +16,7 @@ import { getTableauAuthInfo } from '../../server/oauth/getTableauAuthInfo.js'; import { TableauAuthInfo } from '../../server/oauth/schemas.js'; import { getSiteLuidFromAccessToken } from '../../utils/getSiteLuidFromAccessToken.js'; import { getResultForTableauVersion } from '../../utils/isTableauVersionAtLeast.js'; +import { getConfigWithOverrides } from '../../utils/mcpSiteSettings.js'; import { Provider } from '../../utils/provider.js'; import { getVizqlDataServiceDisabledError } from '../getVizqlDataServiceDisabledError.js'; import { resourceAccessChecker } from '../resourceAccessChecker.js'; @@ -56,7 +57,6 @@ export const getQueryDatasourceTool = ( authInfo?: TableauAuthInfo, ): Tool => { const config = getConfig(); - const queryDatasourceTool = new Tool({ server, name: 'query-datasource', @@ -81,6 +81,18 @@ export const getQueryDatasourceTool = ( { datasourceLuid, query, limit }, { requestId, sessionId, authInfo, signal }, ): Promise => { + const restApiArgs = { + config, + requestId, + server, + signal, + authInfo: getTableauAuthInfo(authInfo), + }; + + const configWithOverrides = await getConfigWithOverrides({ + restApiArgs, + }); + return await queryDatasourceTool.logAndExecute({ requestId, sessionId, @@ -89,7 +101,7 @@ export const getQueryDatasourceTool = ( callback: async () => { const isDatasourceAllowedResult = await resourceAccessChecker.isDatasourceAllowed({ datasourceLuid, - restApiArgs: { config, requestId, server, signal }, + restApiArgs, }); if (!isDatasourceAllowedResult.allowed) { @@ -100,7 +112,7 @@ export const getQueryDatasourceTool = ( } const datasource: Datasource = { datasourceLuid }; - const maxResultLimit = config.getMaxResultLimit(queryDatasourceTool.name); + const maxResultLimit = configWithOverrides.getMaxResultLimit(queryDatasourceTool.name); const rowLimit = maxResultLimit ? Math.min(maxResultLimit, limit ?? Number.MAX_SAFE_INTEGER) : limit; @@ -134,14 +146,10 @@ export const getQueryDatasourceTool = ( }; return await useRestApi({ - config, - requestId, - server, + ...restApiArgs, jwtScopes: ['tableau:viz_data_service:read'], - signal, - authInfo: getTableauAuthInfo(authInfo), callback: async (restApi) => { - if (!config.disableQueryDatasourceValidationRequests) { + if (!configWithOverrides.disableQueryDatasourceValidationRequests) { // Validate query against metadata const metadataValidationResult = await validateQueryAgainstDatasourceMetadata( query, diff --git a/src/tools/resourceAccessChecker.ts b/src/tools/resourceAccessChecker.ts index b0bb9e79..7407a6c9 100644 --- a/src/tools/resourceAccessChecker.ts +++ b/src/tools/resourceAccessChecker.ts @@ -1,29 +1,22 @@ -import { RequestId } from '@modelcontextprotocol/sdk/types.js'; - -import { BoundedContext, Config, getConfig } from '../config.js'; -import { useRestApi } from '../restApiInstance.js'; +import { BoundedContext } from '../overridableConfig.js'; +import { RestApiArgs, useRestApi } from '../restApiInstance.js'; import { DataSource } from '../sdks/tableau/types/dataSource.js'; import { View } from '../sdks/tableau/types/view.js'; import { Workbook } from '../sdks/tableau/types/workbook.js'; -import { Server } from '../server.js'; import { getExceptionMessage } from '../utils/getExceptionMessage.js'; +import { getConfigWithOverrides } from '../utils/mcpSiteSettings.js'; type AllowedResult = | { allowed: true; content?: T } | { allowed: false; message: string }; -export type RestApiArgs = { - config: Config; - requestId: RequestId; - server: Server; - signal: AbortSignal; -}; - class ResourceAccessChecker { - private _allowedProjectIds: Set | null | undefined; - private _allowedDatasourceIds: Set | null | undefined; - private _allowedWorkbookIds: Set | null | undefined; - private _allowedTags: Set | null | undefined; + private _testOverrides: { + projectIds: Set | null | undefined; + datasourceIds: Set | null | undefined; + workbookIds: Set | null | undefined; + tags: Set | null | undefined; + }; private readonly _cachedDatasourceIds: Map; private readonly _cachedWorkbookIds: Map>; @@ -38,48 +31,78 @@ class ResourceAccessChecker { } // Optional bounded context to use for testing. - private constructor(boundedContext?: BoundedContext) { + private constructor(testOverrides?: BoundedContext) { // The methods assume these sets are non-empty. - this._allowedProjectIds = boundedContext?.projectIds; - this._allowedDatasourceIds = boundedContext?.datasourceIds; - this._allowedWorkbookIds = boundedContext?.workbookIds; - this._allowedTags = boundedContext?.tags; + this._testOverrides = { + projectIds: testOverrides?.projectIds, + datasourceIds: testOverrides?.datasourceIds, + workbookIds: testOverrides?.workbookIds, + tags: testOverrides?.tags, + }; this._cachedDatasourceIds = new Map(); this._cachedWorkbookIds = new Map(); this._cachedViewIds = new Map(); } - private get allowedProjectIds(): Set | null { - if (this._allowedProjectIds === undefined) { - this._allowedProjectIds = getConfig().boundedContext.projectIds; - } - - return this._allowedProjectIds; + private async getAllowedProjectIds({ + restApiArgs, + }: { + restApiArgs: RestApiArgs; + }): Promise | null> { + return ( + this._testOverrides.projectIds ?? + ( + await getConfigWithOverrides({ + restApiArgs, + }) + ).boundedContext.projectIds + ); } - private get allowedDatasourceIds(): Set | null { - if (this._allowedDatasourceIds === undefined) { - this._allowedDatasourceIds = getConfig().boundedContext.datasourceIds; - } - - return this._allowedDatasourceIds; + private async getAllowedDatasourceIds({ + restApiArgs, + }: { + restApiArgs: RestApiArgs; + }): Promise | null> { + return ( + this._testOverrides.datasourceIds ?? + ( + await getConfigWithOverrides({ + restApiArgs, + }) + ).boundedContext.datasourceIds + ); } - private get allowedWorkbookIds(): Set | null { - if (this._allowedWorkbookIds === undefined) { - this._allowedWorkbookIds = getConfig().boundedContext.workbookIds; - } - - return this._allowedWorkbookIds; + private async getAllowedWorkbookIds({ + restApiArgs, + }: { + restApiArgs: RestApiArgs; + }): Promise | null> { + return ( + this._testOverrides.workbookIds ?? + ( + await getConfigWithOverrides({ + restApiArgs, + }) + ).boundedContext.workbookIds + ); } - private get allowedTags(): Set | null { - if (this._allowedTags === undefined) { - this._allowedTags = getConfig().boundedContext.tags; - } - - return this._allowedTags; + private async getAllowedTags({ + restApiArgs, + }: { + restApiArgs: RestApiArgs; + }): Promise | null> { + return ( + this._testOverrides.tags ?? + ( + await getConfigWithOverrides({ + restApiArgs, + }) + ).boundedContext.tags + ); } async isDatasourceAllowed({ @@ -94,7 +117,9 @@ class ResourceAccessChecker { restApiArgs, }); - if (!this.allowedProjectIds && !this.allowedTags) { + const allowedProjectIds = await this.getAllowedProjectIds({ restApiArgs }); + const allowedTags = await this.getAllowedTags({ restApiArgs }); + if (!allowedProjectIds && !allowedTags) { // If project filtering is enabled, we cannot cache the result since the datasource may be moved between projects. // If tag filtering is enabled, we cannot cache the result since the datasource tags can change over time. this._cachedDatasourceIds.set(datasourceLuid, result); @@ -115,7 +140,9 @@ class ResourceAccessChecker { restApiArgs, }); - if (!this.allowedProjectIds && !this.allowedTags) { + const allowedProjectIds = await this.getAllowedProjectIds({ restApiArgs }); + const allowedTags = await this.getAllowedTags({ restApiArgs }); + if (!allowedProjectIds && !allowedTags) { // If project filtering is enabled, we cannot cache the result since the workbook may be moved between projects. // If tag filtering is enabled, we cannot cache the result since the workbook tags can change over time. this._cachedWorkbookIds.set(workbookId, result); @@ -136,7 +163,9 @@ class ResourceAccessChecker { restApiArgs, }); - if (!this.allowedProjectIds && !this.allowedTags) { + const allowedProjectIds = await this.getAllowedProjectIds({ restApiArgs }); + const allowedTags = await this.getAllowedTags({ restApiArgs }); + if (!allowedProjectIds && !allowedTags) { // If project filtering is enabled, we cannot cache the result since the workbook containing the view may be moved between projects. // If tag filtering is enabled, we cannot cache the result since the view tags can change over time. this._cachedViewIds.set(viewId, result); @@ -147,7 +176,7 @@ class ResourceAccessChecker { private async _isDatasourceAllowed({ datasourceLuid, - restApiArgs: { config, requestId, server, signal }, + restApiArgs, }: { datasourceLuid: string; restApiArgs: RestApiArgs; @@ -157,7 +186,8 @@ class ResourceAccessChecker { return cachedResult; } - if (this.allowedDatasourceIds && !this.allowedDatasourceIds.has(datasourceLuid)) { + const allowedDatasourceIds = await this.getAllowedDatasourceIds({ restApiArgs }); + if (allowedDatasourceIds && !allowedDatasourceIds.has(datasourceLuid)) { return { allowed: false, message: [ @@ -170,11 +200,8 @@ class ResourceAccessChecker { let datasource: DataSource | undefined; async function getDatasource(): Promise { return await useRestApi({ - config, - requestId, - server, + ...restApiArgs, jwtScopes: ['tableau:content:read'], - signal, callback: async (restApi) => await restApi.datasourcesMethods.queryDatasource({ siteId: restApi.siteId, @@ -183,11 +210,12 @@ class ResourceAccessChecker { }); } - if (this.allowedProjectIds) { + const allowedProjectIds = await this.getAllowedProjectIds({ restApiArgs }); + if (allowedProjectIds) { try { datasource = await getDatasource(); - if (!this.allowedProjectIds.has(datasource.project.id)) { + if (!allowedProjectIds.has(datasource.project.id)) { return { allowed: false, message: [ @@ -208,11 +236,12 @@ class ResourceAccessChecker { } } - if (this.allowedTags) { + const allowedTags = await this.getAllowedTags({ restApiArgs }); + if (allowedTags) { try { datasource = datasource ?? (await getDatasource()); - if (!datasource.tags?.tag?.some((tag) => this.allowedTags?.has(tag.label))) { + if (!datasource.tags?.tag?.some((tag) => allowedTags.has(tag.label))) { return { allowed: false, message: [ @@ -238,7 +267,7 @@ class ResourceAccessChecker { private async _isWorkbookAllowed({ workbookId, - restApiArgs: { config, requestId, server, signal }, + restApiArgs, }: { workbookId: string; restApiArgs: RestApiArgs; @@ -248,7 +277,8 @@ class ResourceAccessChecker { return cachedResult; } - if (this.allowedWorkbookIds && !this.allowedWorkbookIds.has(workbookId)) { + const allowedWorkbookIds = await this.getAllowedWorkbookIds({ restApiArgs }); + if (allowedWorkbookIds && !allowedWorkbookIds.has(workbookId)) { return { allowed: false, message: [ @@ -261,11 +291,8 @@ class ResourceAccessChecker { let workbook: Workbook | undefined; async function getWorkbook(): Promise { return await useRestApi({ - config, - requestId, - server, + ...restApiArgs, jwtScopes: ['tableau:content:read'], - signal, callback: async (restApi) => await restApi.workbooksMethods.getWorkbook({ siteId: restApi.siteId, @@ -274,11 +301,12 @@ class ResourceAccessChecker { }); } - if (this.allowedProjectIds) { + const allowedProjectIds = await this.getAllowedProjectIds({ restApiArgs }); + if (allowedProjectIds) { try { workbook = await getWorkbook(); - if (!this.allowedProjectIds.has(workbook.project?.id ?? '')) { + if (!allowedProjectIds.has(workbook.project?.id ?? '')) { return { allowed: false, message: [ @@ -299,11 +327,12 @@ class ResourceAccessChecker { } } - if (this.allowedTags) { + const allowedTags = await this.getAllowedTags({ restApiArgs }); + if (allowedTags) { try { workbook = workbook ?? (await getWorkbook()); - if (!workbook.tags?.tag?.some((tag) => this.allowedTags?.has(tag.label))) { + if (!workbook.tags?.tag?.some((tag) => allowedTags.has(tag.label))) { return { allowed: false, message: [ @@ -329,7 +358,7 @@ class ResourceAccessChecker { private async _isViewAllowed({ viewId, - restApiArgs: { config, requestId, server, signal }, + restApiArgs, }: { viewId: string; restApiArgs: RestApiArgs; @@ -342,11 +371,8 @@ class ResourceAccessChecker { let view: View | undefined; async function getView(): Promise { return await useRestApi({ - config, - requestId, - server, + ...restApiArgs, jwtScopes: ['tableau:content:read'], - signal, callback: async (restApi) => { return await restApi.viewsMethods.getView({ siteId: restApi.siteId, @@ -356,11 +382,12 @@ class ResourceAccessChecker { }); } - if (this.allowedWorkbookIds) { + const allowedWorkbookIds = await this.getAllowedWorkbookIds({ restApiArgs }); + if (allowedWorkbookIds) { try { view = await getView(); - if (!this.allowedWorkbookIds.has(view.workbook?.id ?? '')) { + if (!allowedWorkbookIds.has(view.workbook?.id ?? '')) { return { allowed: false, message: [ @@ -381,11 +408,12 @@ class ResourceAccessChecker { } } - if (this.allowedProjectIds) { + const allowedProjectIds = await this.getAllowedProjectIds({ restApiArgs }); + if (allowedProjectIds) { try { view = view ?? (await getView()); - if (!this.allowedProjectIds.has(view.project?.id ?? '')) { + if (!allowedProjectIds.has(view.project?.id ?? '')) { return { allowed: false, message: [ @@ -406,11 +434,12 @@ class ResourceAccessChecker { } } - if (this.allowedTags) { + const allowedTags = await this.getAllowedTags({ restApiArgs }); + if (allowedTags) { try { view = view ?? (await getView()); - if (!view.tags?.tag?.some((tag) => this.allowedTags?.has(tag.label))) { + if (!view.tags?.tag?.some((tag) => allowedTags.has(tag.label))) { return { allowed: false, message: [ diff --git a/src/tools/views/getViewData.ts b/src/tools/views/getViewData.ts index d75d0792..68f5991c 100644 --- a/src/tools/views/getViewData.ts +++ b/src/tools/views/getViewData.ts @@ -36,6 +36,13 @@ export const getGetViewDataTool = (server: Server): Tool => { requestId, sessionId, authInfo, signal }, ): Promise => { const config = getConfig(); + const restApiArgs = { + config, + requestId, + server, + signal, + authInfo: getTableauAuthInfo(authInfo), + }; return await getViewDataTool.logAndExecute({ requestId, @@ -45,7 +52,7 @@ export const getGetViewDataTool = (server: Server): Tool => callback: async () => { const isViewAllowedResult = await resourceAccessChecker.isViewAllowed({ viewId, - restApiArgs: { config, requestId, server, signal }, + restApiArgs, }); if (!isViewAllowedResult.allowed) { @@ -57,12 +64,8 @@ export const getGetViewDataTool = (server: Server): Tool => return new Ok( await useRestApi({ - config, - requestId, - server, + ...restApiArgs, jwtScopes: ['tableau:views:download'], - signal, - authInfo: getTableauAuthInfo(authInfo), callback: async (restApi) => { return await restApi.viewsMethods.queryViewData({ viewId, diff --git a/src/tools/views/getViewImage.ts b/src/tools/views/getViewImage.ts index 0b77771d..affc1ab1 100644 --- a/src/tools/views/getViewImage.ts +++ b/src/tools/views/getViewImage.ts @@ -39,6 +39,13 @@ export const getGetViewImageTool = (server: Server): Tool = { requestId, sessionId, authInfo, signal }, ): Promise => { const config = getConfig(); + const restApiArgs = { + config, + requestId, + server, + signal, + authInfo: getTableauAuthInfo(authInfo), + }; return await getViewImageTool.logAndExecute({ requestId, @@ -48,7 +55,7 @@ export const getGetViewImageTool = (server: Server): Tool = callback: async () => { const isViewAllowedResult = await resourceAccessChecker.isViewAllowed({ viewId, - restApiArgs: { config, requestId, server, signal }, + restApiArgs, }); if (!isViewAllowedResult.allowed) { @@ -60,12 +67,8 @@ export const getGetViewImageTool = (server: Server): Tool = return new Ok( await useRestApi({ - config, - requestId, - server, + ...restApiArgs, jwtScopes: ['tableau:views:download'], - signal, - authInfo: getTableauAuthInfo(authInfo), callback: async (restApi) => { return await restApi.viewsMethods.queryViewImage({ viewId, diff --git a/src/tools/views/listViews.ts b/src/tools/views/listViews.ts index e8166b16..a0d05c95 100644 --- a/src/tools/views/listViews.ts +++ b/src/tools/views/listViews.ts @@ -2,12 +2,14 @@ import { CallToolResult } from '@modelcontextprotocol/sdk/types.js'; import { Ok } from 'ts-results-es'; import { z } from 'zod'; -import { BoundedContext, getConfig } from '../../config.js'; +import { getConfig } from '../../config.js'; +import { BoundedContext } from '../../overridableConfig.js'; import { useRestApi } from '../../restApiInstance.js'; import { View } from '../../sdks/tableau/types/view.js'; import { Server } from '../../server.js'; import { getTableauAuthInfo } from '../../server/oauth/getTableauAuthInfo.js'; import { getSiteLuidFromAccessToken } from '../../utils/getSiteLuidFromAccessToken.js'; +import { getConfigWithOverrides } from '../../utils/mcpSiteSettings.js'; import { paginate } from '../../utils/paginate.js'; import { genericFilterDescription } from '../genericFilterDescription.js'; import { ConstrainedResult, Tool } from '../tool.js'; @@ -72,6 +74,18 @@ export const getListViewsTool = (server: Server): Tool => { { requestId, sessionId, authInfo, signal }, ): Promise => { const config = getConfig(); + const restApiArgs = { + config, + requestId, + server, + signal, + authInfo: getTableauAuthInfo(authInfo), + }; + + const configWithOverrides = await getConfigWithOverrides({ + restApiArgs, + }); + const validatedFilter = filter ? parseAndValidateViewsFilterString(filter) : undefined; return await listViewsTool.logAndExecute({ @@ -82,14 +96,10 @@ export const getListViewsTool = (server: Server): Tool => { callback: async () => { return new Ok( await useRestApi({ - config, - requestId, - server, + ...restApiArgs, jwtScopes: ['tableau:content:read'], - signal, - authInfo: getTableauAuthInfo(authInfo), callback: async (restApi) => { - const maxResultLimit = config.getMaxResultLimit(listViewsTool.name); + const maxResultLimit = configWithOverrides.getMaxResultLimit(listViewsTool.name); const views = await paginate({ pageConfig: { pageSize, @@ -117,7 +127,7 @@ export const getListViewsTool = (server: Server): Tool => { ); }, constrainSuccessResult: (views) => - constrainViews({ views, boundedContext: config.boundedContext }), + constrainViews({ views, boundedContext: configWithOverrides.boundedContext }), productTelemetryBase: { endpoint: config.productTelemetryEndpoint, siteLuid: getSiteLuidFromAccessToken(getTableauAuthInfo(authInfo)?.accessToken), diff --git a/src/tools/workbooks/getWorkbook.ts b/src/tools/workbooks/getWorkbook.ts index 84cc9617..5b39caeb 100644 --- a/src/tools/workbooks/getWorkbook.ts +++ b/src/tools/workbooks/getWorkbook.ts @@ -2,12 +2,14 @@ import { CallToolResult } from '@modelcontextprotocol/sdk/types.js'; import { Err, Ok } from 'ts-results-es'; import { z } from 'zod'; -import { BoundedContext, getConfig } from '../../config.js'; +import { getConfig } from '../../config.js'; +import { BoundedContext } from '../../overridableConfig.js'; import { useRestApi } from '../../restApiInstance.js'; import { Workbook } from '../../sdks/tableau/types/workbook.js'; import { Server } from '../../server.js'; import { getTableauAuthInfo } from '../../server/oauth/getTableauAuthInfo.js'; import { getSiteLuidFromAccessToken } from '../../utils/getSiteLuidFromAccessToken.js'; +import { getConfigWithOverrides } from '../../utils/mcpSiteSettings.js'; import { resourceAccessChecker } from '../resourceAccessChecker.js'; import { ConstrainedResult, Tool } from '../tool.js'; @@ -37,6 +39,17 @@ export const getGetWorkbookTool = (server: Server): Tool => { requestId, sessionId, authInfo, signal }, ): Promise => { const config = getConfig(); + const restApiArgs = { + config, + requestId, + server, + signal, + authInfo: getTableauAuthInfo(authInfo), + }; + + const configWithOverrides = await getConfigWithOverrides({ + restApiArgs, + }); return await getWorkbookTool.logAndExecute({ requestId, @@ -46,7 +59,7 @@ export const getGetWorkbookTool = (server: Server): Tool => callback: async () => { const isWorkbookAllowedResult = await resourceAccessChecker.isWorkbookAllowed({ workbookId, - restApiArgs: { config, requestId, server, signal }, + restApiArgs, }); if (!isWorkbookAllowedResult.allowed) { @@ -58,12 +71,8 @@ export const getGetWorkbookTool = (server: Server): Tool => return new Ok( await useRestApi({ - config, - requestId, - server, + ...restApiArgs, jwtScopes: ['tableau:content:read'], - signal, - authInfo: getTableauAuthInfo(authInfo), callback: async (restApi) => { // Notice that we already have the workbook if it had been allowed by a project scope. const workbook = @@ -91,7 +100,7 @@ export const getGetWorkbookTool = (server: Server): Tool => ); }, constrainSuccessResult: (workbook) => - filterWorkbookViews({ workbook, boundedContext: config.boundedContext }), + filterWorkbookViews({ workbook, boundedContext: configWithOverrides.boundedContext }), getErrorText: (error: GetWorkbookError) => { switch (error.type) { case 'workbook-not-allowed': diff --git a/src/tools/workbooks/listWorkbooks.ts b/src/tools/workbooks/listWorkbooks.ts index f4803268..16b25743 100644 --- a/src/tools/workbooks/listWorkbooks.ts +++ b/src/tools/workbooks/listWorkbooks.ts @@ -2,12 +2,14 @@ import { CallToolResult } from '@modelcontextprotocol/sdk/types.js'; import { Ok } from 'ts-results-es'; import { z } from 'zod'; -import { BoundedContext, getConfig } from '../../config.js'; +import { getConfig } from '../../config.js'; +import { BoundedContext } from '../../overridableConfig.js'; import { useRestApi } from '../../restApiInstance.js'; import { Workbook } from '../../sdks/tableau/types/workbook.js'; import { Server } from '../../server.js'; import { getTableauAuthInfo } from '../../server/oauth/getTableauAuthInfo.js'; import { getSiteLuidFromAccessToken } from '../../utils/getSiteLuidFromAccessToken.js'; +import { getConfigWithOverrides } from '../../utils/mcpSiteSettings.js'; import { paginate } from '../../utils/paginate.js'; import { genericFilterDescription } from '../genericFilterDescription.js'; import { ConstrainedResult, Tool } from '../tool.js'; @@ -69,6 +71,18 @@ export const getListWorkbooksTool = (server: Server): Tool { requestId, sessionId, authInfo, signal }, ): Promise => { const config = getConfig(); + const restApiArgs = { + config, + requestId, + server, + signal, + authInfo: getTableauAuthInfo(authInfo), + }; + + const configWithOverrides = await getConfigWithOverrides({ + restApiArgs, + }); + const validatedFilter = filter ? parseAndValidateWorkbooksFilterString(filter) : undefined; return await listWorkbooksTool.logAndExecute({ @@ -79,14 +93,13 @@ export const getListWorkbooksTool = (server: Server): Tool callback: async () => { return new Ok( await useRestApi({ - config, - requestId, - server, + ...restApiArgs, jwtScopes: ['tableau:content:read'], - signal, - authInfo: getTableauAuthInfo(authInfo), callback: async (restApi) => { - const maxResultLimit = config.getMaxResultLimit(listWorkbooksTool.name); + const maxResultLimit = configWithOverrides.getMaxResultLimit( + listWorkbooksTool.name, + ); + const workbooks = await paginate({ pageConfig: { pageSize, @@ -113,7 +126,7 @@ export const getListWorkbooksTool = (server: Server): Tool ); }, constrainSuccessResult: (workbooks) => - constrainWorkbooks({ workbooks, boundedContext: config.boundedContext }), + constrainWorkbooks({ workbooks, boundedContext: configWithOverrides.boundedContext }), productTelemetryBase: { endpoint: config.productTelemetryEndpoint, siteLuid: getSiteLuidFromAccessToken(getTableauAuthInfo(authInfo)?.accessToken), diff --git a/src/utils/mcpSiteSettings.test.ts b/src/utils/mcpSiteSettings.test.ts new file mode 100644 index 00000000..af18ec9d --- /dev/null +++ b/src/utils/mcpSiteSettings.test.ts @@ -0,0 +1,112 @@ +import { Server } from '../server'; +import { stubDefaultEnvVars } from '../testShared'; +import { getConfigWithOverrides } from './mcpSiteSettings'; + +const mocks = vi.hoisted(() => ({ + mockGetMcpSiteSettings: vi.fn(), +})); + +vi.mock('../restApiInstance.js', () => ({ + useRestApi: vi.fn().mockImplementation(async ({ callback }) => + callback({ + siteMethods: { + getMcpSettings: mocks.mockGetMcpSiteSettings, + }, + }), + ), +})); + +describe('mcpSiteSettings', () => { + beforeEach(() => { + vi.unstubAllEnvs(); + stubDefaultEnvVars(); + }); + + afterEach(() => { + vi.unstubAllEnvs(); + }); + + it('should not override any settings when enableMcpSiteSettings is false', async () => { + vi.stubEnv('ENABLE_MCP_SITE_SETTINGS', 'false'); + const config = await getConfigWithOverrides({ + restApiArgs: { + server: new Server(), + disableLogging: true, + }, + }); + + expect(config.includeTools).toEqual([]); + expect(config.excludeTools).toEqual([]); + expect(config.boundedContext).toEqual({ + projectIds: null, + datasourceIds: null, + workbookIds: null, + tags: null, + }); + expect(config.getMaxResultLimit('query-datasource')).toEqual(null); + expect(config.disableQueryDatasourceValidationRequests).toEqual(false); + expect(config.disableMetadataApiRequests).toEqual(false); + + expect(mocks.mockGetMcpSiteSettings).not.toHaveBeenCalled(); + }); + + it('should override settings when enableMcpSiteSettings is true', async () => { + vi.stubEnv('ENABLE_MCP_SITE_SETTINGS', 'true'); + mocks.mockGetMcpSiteSettings.mockResolvedValue({ + INCLUDE_TOOLS: 'list-views,list-datasources', + INCLUDE_PROJECT_IDS: 'project1,project2', + INCLUDE_DATASOURCE_IDS: 'datasource1,datasource2', + INCLUDE_WORKBOOK_IDS: 'workbook1,workbook2', + INCLUDE_TAGS: 'tag1,tag2', + MAX_RESULT_LIMIT: '100', + MAX_RESULT_LIMITS: 'query-datasource:100,list-datasources:20', + DISABLE_QUERY_DATASOURCE_VALIDATION_REQUESTS: 'true', + DISABLE_METADATA_API_REQUESTS: 'true', + }); + + let config = await getConfigWithOverrides({ + restApiArgs: { + server: new Server(), + disableLogging: true, + }, + }); + + expect(config.includeTools).toEqual(['list-views', 'list-datasources']); + expect(config.excludeTools).toEqual([]); + expect(config.boundedContext).toEqual({ + projectIds: new Set(['project1', 'project2']), + datasourceIds: new Set(['datasource1', 'datasource2']), + workbookIds: new Set(['workbook1', 'workbook2']), + tags: new Set(['tag1', 'tag2']), + }); + expect(config.getMaxResultLimit('query-datasource')).toEqual(100); + expect(config.getMaxResultLimit('list-datasources')).toEqual(20); + expect(config.disableQueryDatasourceValidationRequests).toEqual(true); + expect(config.disableMetadataApiRequests).toEqual(true); + + expect(mocks.mockGetMcpSiteSettings).toHaveBeenCalledTimes(1); + + // Verify cache behavior + config = await getConfigWithOverrides({ + restApiArgs: { + server: new Server(), + disableLogging: true, + }, + }); + + expect(config.includeTools).toEqual(['list-views', 'list-datasources']); + expect(config.excludeTools).toEqual([]); + expect(config.boundedContext).toEqual({ + projectIds: new Set(['project1', 'project2']), + datasourceIds: new Set(['datasource1', 'datasource2']), + workbookIds: new Set(['workbook1', 'workbook2']), + tags: new Set(['tag1', 'tag2']), + }); + expect(config.getMaxResultLimit('query-datasource')).toEqual(100); + expect(config.getMaxResultLimit('list-datasources')).toEqual(20); + expect(config.disableQueryDatasourceValidationRequests).toEqual(true); + expect(config.disableMetadataApiRequests).toEqual(true); + + expect(mocks.mockGetMcpSiteSettings).toHaveBeenCalledTimes(1); + }); +}); diff --git a/src/utils/mcpSiteSettings.ts b/src/utils/mcpSiteSettings.ts new file mode 100644 index 00000000..08f85898 --- /dev/null +++ b/src/utils/mcpSiteSettings.ts @@ -0,0 +1,65 @@ +import { Config, getConfig } from '../config.js'; +import { getOverridableConfig, OverridableConfig } from '../overridableConfig.js'; +import { RestApiArgs, useRestApi } from '../restApiInstance.js'; +import { McpSiteSettings } from '../sdks/tableau/types/mcpSiteSettings.js'; +import { ExpiringMap } from './expiringMap.js'; +import { getSiteLuidFromAccessToken } from './getSiteLuidFromAccessToken.js'; +import { DistributiveOmit } from './types.js'; + +type SiteNameOrSiteId = string; +let mcpSiteSettingsCache: ExpiringMap; + +async function getMcpSiteSettings({ + restApiArgs, +}: { + restApiArgs: RestApiArgs; +}): Promise { + const { config, authInfo } = restApiArgs; + if (!config.enableMcpSiteSettings) { + return; + } + + if (!mcpSiteSettingsCache) { + mcpSiteSettingsCache = new ExpiringMap({ + defaultExpirationTimeMs: config.mcpSiteSettingsCheckIntervalInMinutes * 60 * 1000, + }); + } + + const cacheKey = config.siteName || getSiteLuidFromAccessToken(authInfo?.accessToken); + if (!cacheKey) { + throw new Error('Could not determine site ID/name'); + } + + const cachedSettings = mcpSiteSettingsCache.get(cacheKey); + if (cachedSettings) { + return cachedSettings; + } + + const settings = await useRestApi({ + ...restApiArgs, + jwtScopes: ['tableau:mcp_site_settings:read'], + callback: async (restApi) => await restApi.siteMethods.getMcpSettings(), + }); + + mcpSiteSettingsCache.set(cacheKey, settings); + return settings; +} + +// Make "config" and "signal" optional +type GetConfigWithOverridesArgs = DistributiveOmit & + Partial<{ config: Config; signal: AbortSignal }>; + +export async function getConfigWithOverrides({ + restApiArgs, +}: { + restApiArgs: GetConfigWithOverridesArgs; +}): Promise { + const config = restApiArgs.config ?? getConfig(); + const signal = restApiArgs.signal ?? AbortSignal.timeout(config.maxRequestTimeoutMs); + + const overrides = await getMcpSiteSettings({ + restApiArgs: { ...restApiArgs, config, signal }, + }); + + return getOverridableConfig(overrides); +} diff --git a/src/utils/types.ts b/src/utils/types.ts new file mode 100644 index 00000000..471cc536 --- /dev/null +++ b/src/utils/types.ts @@ -0,0 +1,3 @@ +// "Omit" is not distributive over unions. +// The fix is to wrap Omit in a distributive conditional type so it is applied to each union member individually. +export type DistributiveOmit = T extends any ? Omit : never; diff --git a/types/process-env.d.ts b/types/process-env.d.ts index 213c8cce..7ba7ddc8 100644 --- a/types/process-env.d.ts +++ b/types/process-env.d.ts @@ -40,6 +40,8 @@ export interface ProcessEnvEx { INCLUDE_WORKBOOK_IDS: string | undefined; INCLUDE_TAGS: string | undefined; TABLEAU_SERVER_VERSION_CHECK_INTERVAL_IN_HOURS: string | undefined; + MCP_SITE_SETTINGS_CHECK_INTERVAL_IN_MINUTES: string | undefined; + ENABLE_MCP_SITE_SETTINGS: string | undefined; DANGEROUSLY_DISABLE_OAUTH: string | undefined; OAUTH_ISSUER: string | undefined; OAUTH_JWE_PRIVATE_KEY: string | undefined;