diff --git a/src/apiClients.ts b/src/apiClients.ts new file mode 100644 index 00000000..8f8ac423 --- /dev/null +++ b/src/apiClients.ts @@ -0,0 +1,186 @@ +import { RequestId } from '@modelcontextprotocol/sdk/types.js'; + +import { getConfig } from './config.js'; +import { log, shouldLogWhenLevelIsAtLeast } from './logging/log.js'; +import { maskRequest, maskResponse } from './logging/secretMask.js'; +import { + AxiosInterceptor, + AxiosResponseInterceptorConfig, + ErrorInterceptor, + getRequestInterceptorConfig, + getResponseInterceptorConfig, + RequestInterceptor, + RequestInterceptorConfig, + ResponseInterceptor, + ResponseInterceptorConfig, +} from './sdks/tableau/interceptors.js'; +import { Server, userAgent } from './server.js'; +import { isAxiosError } from './utils/axios.js'; +import { getExceptionMessage } from './utils/getExceptionMessage.js'; + +export const getRequestInterceptor = + (server: Server, requestId: RequestId, logger: string): RequestInterceptor => + (request) => { + request.headers['User-Agent'] = getUserAgent(server); + logRequest(server, request, requestId, logger); + return request; + }; + +export const getRequestErrorInterceptor = + (server: Server, requestId: RequestId, logger: string): ErrorInterceptor => + (error, baseUrl) => { + if (!isAxiosError(error) || !error.request) { + log.error(server, `Request ${requestId} failed with error: ${getExceptionMessage(error)}`, { + logger, + requestId, + }); + return; + } + + const { request } = error; + logRequest( + server, + { + baseUrl, + ...getRequestInterceptorConfig(request), + }, + requestId, + logger, + ); + }; + +export const getResponseInterceptor = + (server: Server, requestId: RequestId, logger: string): ResponseInterceptor => + (response) => { + logResponse(server, response, requestId, logger); + return response; + }; + +export const getResponseErrorInterceptor = + (server: Server, requestId: RequestId, logger: string): ErrorInterceptor => + (error, baseUrl) => { + if (!isAxiosError(error) || !error.response) { + log.error( + server, + `Response from request ${requestId} failed with error: ${getExceptionMessage(error)}`, + { logger, requestId }, + ); + return; + } + + // The type for the AxiosResponse headers is complex and not directly assignable to that of the Axios response interceptor's. + const { response } = error as { response: AxiosResponseInterceptorConfig }; + logResponse( + server, + { + baseUrl, + ...getResponseInterceptorConfig(response), + }, + requestId, + logger, + ); + }; + +function logRequest( + server: Server, + request: RequestInterceptorConfig, + requestId: RequestId, + logger: string, +): void { + const config = getConfig(); + const maskedRequest = config.disableLogMasking ? request : maskRequest(request); + const url = new URL( + `${maskedRequest.baseUrl.replace(/\/$/, '')}/${maskedRequest.url?.replace(/^\//, '') ?? ''}`, + ); + if (request.params && Object.keys(request.params).length > 0) { + url.search = new URLSearchParams(request.params).toString(); + } + + const messageObj = { + type: 'request', + requestId, + method: maskedRequest.method, + url: url.toString(), + ...(shouldLogWhenLevelIsAtLeast('debug') && { + headers: maskedRequest.headers, + data: maskedRequest.data, + params: maskedRequest.params, + }), + } as const; + + log.info(server, messageObj, { logger, requestId }); +} + +function logResponse( + server: Server, + response: ResponseInterceptorConfig, + requestId: RequestId, + logger: string, +): void { + const config = getConfig(); + const maskedResponse = config.disableLogMasking ? response : maskResponse(response); + const url = new URL( + `${maskedResponse.baseUrl.replace(/\/$/, '')}/${maskedResponse.url?.replace(/^\//, '') ?? ''}`, + ); + if (response.request?.params && Object.keys(response.request.params).length > 0) { + url.search = new URLSearchParams(response.request.params).toString(); + } + const messageObj = { + type: 'response', + requestId, + url: url.toString(), + status: maskedResponse.status, + ...(shouldLogWhenLevelIsAtLeast('debug') && { + headers: maskedResponse.headers, + data: maskedResponse.data, + }), + } as const; + + log.info(server, messageObj, { logger, requestId }); +} + +function getUserAgent(server: Server): string { + const userAgentParts = [userAgent]; + if (server.clientInfo) { + const { name, version } = server.clientInfo; + if (name) { + userAgentParts.push(version ? `(${name} ${version})` : `(${name})`); + } + } + return userAgentParts.join(' '); +} + +export const addInterceptors = ( + baseUrl: string, + axiosInterceptors: AxiosInterceptor, + requestInterceptors?: [RequestInterceptor, ErrorInterceptor?], + responseInterceptors?: [ResponseInterceptor, ErrorInterceptor?], +): void => { + axiosInterceptors.request.use( + (config) => { + requestInterceptors?.[0]({ + baseUrl, + ...getRequestInterceptorConfig(config), + }); + return config; + }, + (error) => { + requestInterceptors?.[1]?.(error, baseUrl); + return Promise.reject(error); + }, + ); + + axiosInterceptors.response.use( + (response) => { + responseInterceptors?.[0]({ + baseUrl, + ...getResponseInterceptorConfig(response), + }); + return response; + }, + (error) => { + responseInterceptors?.[1]?.(error, baseUrl); + return Promise.reject(error); + }, + ); +}; diff --git a/src/restApiInstance.test.ts b/src/restApiInstance.test.ts index 9a9c07d2..ffa56fdb 100644 --- a/src/restApiInstance.test.ts +++ b/src/restApiInstance.test.ts @@ -1,14 +1,14 @@ import { beforeEach, describe, expect, it, vi } from 'vitest'; -import { getConfig } from './config.js'; -import { log } from './logging/log.js'; import { getRequestErrorInterceptor, getRequestInterceptor, getResponseErrorInterceptor, getResponseInterceptor, - useRestApi, -} from './restApiInstance.js'; +} from './apiClients.js'; +import { getConfig } from './config.js'; +import { log } from './logging/log.js'; +import { useRestApi } from './restApiInstance.js'; import { AuthConfig } from './sdks/tableau/authConfig.js'; import { RestApi } from './sdks/tableau/restApi.js'; import { Server, userAgent } from './server.js'; @@ -55,7 +55,7 @@ describe('restApiInstance', () => { describe('Request Interceptor', () => { it('should add User-Agent header and log request', () => { const server = new Server(); - const interceptor = getRequestInterceptor(server, mockRequestId); + const interceptor = getRequestInterceptor(server, mockRequestId, 'rest-api'); const mockRequest = { headers: {} as Record, method: 'GET', @@ -85,7 +85,7 @@ describe('restApiInstance', () => { describe('Response Interceptor', () => { it('should log response', () => { const server = new Server(); - const interceptor = getResponseInterceptor(server, mockRequestId); + const interceptor = getResponseInterceptor(server, mockRequestId, 'rest-api'); const mockResponse = { status: 200, url: '/api/test', @@ -116,7 +116,7 @@ describe('restApiInstance', () => { describe('Error Handling', () => { it('should handle request errors', () => { const server = new Server(); - const errorInterceptor = getRequestErrorInterceptor(server, mockRequestId); + const errorInterceptor = getRequestErrorInterceptor(server, mockRequestId, 'rest-api'); const mockError = { request: { method: 'GET', @@ -140,7 +140,7 @@ describe('restApiInstance', () => { it('should handle AxiosError request errors', () => { const server = new Server(); - const errorInterceptor = getRequestErrorInterceptor(server, mockRequestId); + const errorInterceptor = getRequestErrorInterceptor(server, mockRequestId, 'rest-api'); const mockError = { isAxiosError: true, request: { @@ -172,7 +172,7 @@ describe('restApiInstance', () => { it('should handle response errors', () => { const server = new Server(); - const errorInterceptor = getResponseErrorInterceptor(server, mockRequestId); + const errorInterceptor = getResponseErrorInterceptor(server, mockRequestId, 'rest-api'); const mockError = { response: { status: 500, @@ -197,7 +197,7 @@ describe('restApiInstance', () => { it('should handle AxiosError response errors', () => { const server = new Server(); - const errorInterceptor = getResponseErrorInterceptor(server, mockRequestId); + const errorInterceptor = getResponseErrorInterceptor(server, mockRequestId, 'rest-api'); const mockError = { isAxiosError: true, response: { diff --git a/src/restApiInstance.ts b/src/restApiInstance.ts index 50ba844d..464e17c5 100644 --- a/src/restApiInstance.ts +++ b/src/restApiInstance.ts @@ -1,23 +1,16 @@ import { RequestId } from '@modelcontextprotocol/sdk/types.js'; -import { Config, getConfig } from './config.js'; -import { log, shouldLogWhenLevelIsAtLeast } from './logging/log.js'; -import { maskRequest, maskResponse } from './logging/secretMask.js'; import { - AxiosResponseInterceptorConfig, - ErrorInterceptor, - getRequestInterceptorConfig, - getResponseInterceptorConfig, - RequestInterceptor, - RequestInterceptorConfig, - ResponseInterceptor, - ResponseInterceptorConfig, -} from './sdks/tableau/interceptors.js'; + getRequestErrorInterceptor, + getRequestInterceptor, + getResponseErrorInterceptor, + getResponseInterceptor, +} from './apiClients.js'; +import { Config } from './config.js'; +import { log } from './logging/log.js'; import { RestApi } from './sdks/tableau/restApi.js'; -import { Server, userAgent } from './server.js'; +import { Server } from './server.js'; import { TableauAuthInfo } from './server/oauth/schemas.js'; -import { isAxiosError } from './utils/axios.js'; -import { getExceptionMessage } from './utils/getExceptionMessage.js'; import invariant from './utils/invariant.js'; type JwtScopes = @@ -61,12 +54,12 @@ const getNewRestApiInstanceAsync = async ( maxRequestTimeoutMs: config.maxRequestTimeoutMs, signal, requestInterceptor: [ - getRequestInterceptor(server, requestId), - getRequestErrorInterceptor(server, requestId), + getRequestInterceptor(server, requestId, 'rest-api'), + getRequestErrorInterceptor(server, requestId, 'rest-api'), ], responseInterceptor: [ - getResponseInterceptor(server, requestId), - getResponseErrorInterceptor(server, requestId), + getResponseInterceptor(server, requestId, 'rest-api'), + getResponseErrorInterceptor(server, requestId, 'rest-api'), ], }); @@ -149,130 +142,6 @@ export const useRestApi = async ({ } }; -export const getRequestInterceptor = - (server: Server, requestId: RequestId): RequestInterceptor => - (request) => { - request.headers['User-Agent'] = getUserAgent(server); - logRequest(server, request, requestId); - return request; - }; - -export const getRequestErrorInterceptor = - (server: Server, requestId: RequestId): ErrorInterceptor => - (error, baseUrl) => { - if (!isAxiosError(error) || !error.request) { - log.error(server, `Request ${requestId} failed with error: ${getExceptionMessage(error)}`, { - logger: 'rest-api', - requestId, - }); - return; - } - - const { request } = error; - logRequest( - server, - { - baseUrl, - ...getRequestInterceptorConfig(request), - }, - requestId, - ); - }; - -export const getResponseInterceptor = - (server: Server, requestId: RequestId): ResponseInterceptor => - (response) => { - logResponse(server, response, requestId); - return response; - }; - -export const getResponseErrorInterceptor = - (server: Server, requestId: RequestId): ErrorInterceptor => - (error, baseUrl) => { - if (!isAxiosError(error) || !error.response) { - log.error( - server, - `Response from request ${requestId} failed with error: ${getExceptionMessage(error)}`, - { logger: 'rest-api', requestId }, - ); - return; - } - - // The type for the AxiosResponse headers is complex and not directly assignable to that of the Axios response interceptor's. - const { response } = error as { response: AxiosResponseInterceptorConfig }; - logResponse( - server, - { - baseUrl, - ...getResponseInterceptorConfig(response), - }, - requestId, - ); - }; - -function logRequest(server: Server, request: RequestInterceptorConfig, requestId: RequestId): void { - const config = getConfig(); - const maskedRequest = config.disableLogMasking ? request : maskRequest(request); - const url = new URL( - `${maskedRequest.baseUrl.replace(/\/$/, '')}/${maskedRequest.url?.replace(/^\//, '') ?? ''}`, - ); - if (request.params && Object.keys(request.params).length > 0) { - url.search = new URLSearchParams(request.params).toString(); - } - - const messageObj = { - type: 'request', - requestId, - method: maskedRequest.method, - url: url.toString(), - ...(shouldLogWhenLevelIsAtLeast('debug') && { - headers: maskedRequest.headers, - data: maskedRequest.data, - params: maskedRequest.params, - }), - } as const; - - log.info(server, messageObj, { logger: 'rest-api', requestId }); -} - -function logResponse( - server: Server, - response: ResponseInterceptorConfig, - requestId: RequestId, -): void { - const config = getConfig(); - const maskedResponse = config.disableLogMasking ? response : maskResponse(response); - const url = new URL( - `${maskedResponse.baseUrl.replace(/\/$/, '')}/${maskedResponse.url?.replace(/^\//, '') ?? ''}`, - ); - if (response.request?.params && Object.keys(response.request.params).length > 0) { - url.search = new URLSearchParams(response.request.params).toString(); - } - const messageObj = { - type: 'response', - requestId, - url: url.toString(), - status: maskedResponse.status, - ...(shouldLogWhenLevelIsAtLeast('debug') && { - headers: maskedResponse.headers, - data: maskedResponse.data, - }), - } as const; - - log.info(server, messageObj, { logger: 'rest-api', requestId }); -} - -function getUserAgent(server: Server): string { - const userAgentParts = [userAgent]; - if (server.clientInfo) { - const { name, version } = server.clientInfo; - if (name) { - userAgentParts.push(version ? `(${name} ${version})` : `(${name})`); - } - } - return userAgentParts.join(' '); -} - function getJwtUsername(config: Config, authInfo: TableauAuthInfo | undefined): string { return config.jwtUsername.replaceAll('{OAUTH_USERNAME}', authInfo?.username ?? ''); } diff --git a/src/sdks/plugins/headerExtractorPlugin.ts b/src/sdks/plugins/headerExtractorPlugin.ts new file mode 100644 index 00000000..58d2d783 --- /dev/null +++ b/src/sdks/plugins/headerExtractorPlugin.ts @@ -0,0 +1,71 @@ +import { ZodiosEndpointDefinitions, ZodiosInstance, ZodiosPlugin } from '@zodios/core'; + +import { Deferred } from '../../../tests/oauth/deferred'; +import { AxiosResponse, getStringResponseHeader } from '../../utils/axios'; + +type HeaderExtractorOptions = { + headerName: string; + onHeader: (value: string, response: AxiosResponse) => void; +}; + +const HEADER_EXTRACTOR_PLUGIN_NAME = 'header-extractor'; + +const headerExtractorPlugin = ({ headerName, onHeader }: HeaderExtractorOptions): ZodiosPlugin => { + return { + name: HEADER_EXTRACTOR_PLUGIN_NAME, + response: async (_api, _config, response) => { + const headerValue = getStringResponseHeader(response.headers, headerName); + onHeader(headerValue, response); + return response; + }, + }; +}; + +export async function useHeaderExtractorPlugin({ + client, + headerName, + clientCallback, + timeoutMs, + signal, +}: { + client: ZodiosInstance; + headerName: string; + clientCallback: (client: ZodiosInstance) => Promise; + timeoutMs?: number; + signal?: AbortSignal; +}): Promise<{ result: TReturn; headerValue: string }> { + const deferredHeader = new Deferred(); + + let timeoutId: NodeJS.Timeout | undefined; + let abortListener: (() => void) | undefined; + + if (timeoutMs !== undefined) { + timeoutId = setTimeout(() => deferredHeader.resolve(''), timeoutMs); + } + + if (signal) { + abortListener = () => deferredHeader.resolve(''); + signal.addEventListener('abort', abortListener); + } + + try { + client.use( + headerExtractorPlugin({ headerName, onHeader: (value) => deferredHeader.resolve(value) }), + ); + + const result = await clientCallback(client); + const headerValue = await deferredHeader.promise; + + return { result, headerValue }; + } finally { + if (timeoutId !== undefined) { + clearTimeout(timeoutId); + } + + if (signal && abortListener) { + signal.removeEventListener('abort', abortListener); + } + + client.eject(HEADER_EXTRACTOR_PLUGIN_NAME); + } +} diff --git a/src/sdks/tableau-vizql/apis.ts b/src/sdks/tableau-vizql/apis.ts new file mode 100644 index 00000000..2694df7c --- /dev/null +++ b/src/sdks/tableau-vizql/apis.ts @@ -0,0 +1,31 @@ +import { makeApi, makeEndpoint, ZodiosEndpointDefinitions } from '@zodios/core'; +import { z } from 'zod'; + +export const startSessionEndpoint = makeEndpoint({ + method: 'post', + path: '/vizql/t/:siteName/w/:workbookName/v/:viewName/startSession/viewing', + alias: 'startSession', + parameters: [ + { + name: 'siteName', + type: 'Path', + schema: z.string(), + }, + { + name: 'workbookName', + type: 'Path', + schema: z.string(), + }, + { + name: 'viewName', + type: 'Path', + schema: z.string(), + }, + ], + response: z.object({ + sessionid: z.string(), + }), +}); + +const vizqlApi = makeApi([startSessionEndpoint]); +export const vizqlApis = [...vizqlApi] as const satisfies ZodiosEndpointDefinitions; diff --git a/src/sdks/tableau-vizql/client.ts b/src/sdks/tableau-vizql/client.ts new file mode 100644 index 00000000..27384308 --- /dev/null +++ b/src/sdks/tableau-vizql/client.ts @@ -0,0 +1,10 @@ +import { Zodios, ZodiosInstance } from '@zodios/core'; + +import { AxiosRequestConfig } from '../../utils/axios.js'; +import { vizqlApis } from './apis.js'; + +export const getClient = (basePath: string, axiosConfig: AxiosRequestConfig): VizqlClient => { + return new Zodios(basePath, vizqlApis, { axiosConfig }); +}; + +export type VizqlClient = ZodiosInstance; diff --git a/src/sdks/tableau/restApi.ts b/src/sdks/tableau/restApi.ts index 152baceb..5a78b982 100644 --- a/src/sdks/tableau/restApi.ts +++ b/src/sdks/tableau/restApi.ts @@ -1,9 +1,8 @@ +import { addInterceptors } from '../../apiClients.js'; import { AuthConfig } from './authConfig.js'; import { AxiosInterceptor, ErrorInterceptor, - getRequestInterceptorConfig, - getResponseInterceptorConfig, RequestInterceptor, ResponseInterceptor, } from './interceptors.js'; @@ -68,7 +67,7 @@ export class RestApi { this._responseInterceptor = options.responseInterceptor; } - private get creds(): Credentials { + get creds(): Credentials { if (!this._creds) { throw new Error('No credentials found. Authenticate by calling signIn() first.'); } @@ -249,33 +248,12 @@ export class RestApi { }; }; - private _addInterceptors = (baseUrl: string, interceptors: AxiosInterceptor): void => { - interceptors.request.use( - (config) => { - this._requestInterceptor?.[0]({ - baseUrl, - ...getRequestInterceptorConfig(config), - }); - return config; - }, - (error) => { - this._requestInterceptor?.[1]?.(error, baseUrl); - return Promise.reject(error); - }, - ); - - interceptors.response.use( - (response) => { - this._responseInterceptor?.[0]({ - baseUrl, - ...getResponseInterceptorConfig(response), - }); - return response; - }, - (error) => { - this._responseInterceptor?.[1]?.(error, baseUrl); - return Promise.reject(error); - }, + private _addInterceptors = (baseUrl: string, axiosInterceptors: AxiosInterceptor): void => { + addInterceptors( + baseUrl, + axiosInterceptors, + this._requestInterceptor, + this._responseInterceptor, ); }; } diff --git a/src/tools/toolName.ts b/src/tools/toolName.ts index 137e95a4..a120e6be 100644 --- a/src/tools/toolName.ts +++ b/src/tools/toolName.ts @@ -15,6 +15,7 @@ export const toolNames = [ 'generate-pulse-metric-value-insight-bundle', 'generate-pulse-insight-brief', 'search-content', + 'create-workbook-session', ] as const; export type ToolName = (typeof toolNames)[number]; @@ -29,7 +30,7 @@ export type ToolGroupName = (typeof toolGroupNames)[number]; export const toolGroups = { datasource: ['list-datasources', 'get-datasource-metadata', 'query-datasource'], - workbook: ['list-workbooks', 'get-workbook'], + workbook: ['list-workbooks', 'get-workbook', 'create-workbook-session'], view: ['list-views', 'get-view-data', 'get-view-image'], pulse: [ 'list-all-pulse-metric-definitions', diff --git a/src/tools/tools.ts b/src/tools/tools.ts index 9dbc212c..917e67b9 100644 --- a/src/tools/tools.ts +++ b/src/tools/tools.ts @@ -12,6 +12,7 @@ import { getQueryDatasourceTool } from './queryDatasource/queryDatasource.js'; import { getGetViewDataTool } from './views/getViewData.js'; import { getGetViewImageTool } from './views/getViewImage.js'; import { getListViewsTool } from './views/listViews.js'; +import { getCreateWorkbookSessionTool } from './workbooks/createWorkbookSession.js'; import { getGetWorkbookTool } from './workbooks/getWorkbook.js'; import { getListWorkbooksTool } from './workbooks/listWorkbooks.js'; @@ -32,4 +33,5 @@ export const toolFactories = [ getListWorkbooksTool, getListViewsTool, getSearchContentTool, + getCreateWorkbookSessionTool, ]; diff --git a/src/tools/workbooks/createWorkbookSession.ts b/src/tools/workbooks/createWorkbookSession.ts new file mode 100644 index 00000000..570571ac --- /dev/null +++ b/src/tools/workbooks/createWorkbookSession.ts @@ -0,0 +1,149 @@ +import { CallToolResult } from '@modelcontextprotocol/sdk/types.js'; +import { Err, Ok } from 'ts-results-es'; +import { z } from 'zod'; + +import { getConfig } from '../../config.js'; +import { useRestApi } from '../../restApiInstance.js'; +import { useHeaderExtractorPlugin } from '../../sdks/plugins/headerExtractorPlugin.js'; +import { Server } from '../../server.js'; +import { getTableauAuthInfo } from '../../server/oauth/getTableauAuthInfo.js'; +import { getNewVizqlApiInstanceAsync } from '../../vizqlApiInstance.js'; +import { resourceAccessChecker } from '../resourceAccessChecker.js'; +import { Tool } from '../tool.js'; + +const paramsSchema = { + workbookId: z.string().describe('The ID of the workbook to create a session for.'), + viewId: z + .string() + .optional() + .describe( + 'The ID of the view to create a session for. If not provided, the default view of the workbook will be used.', + ), +}; + +export type CreateWorkbookSessionError = { + type: 'workbook-not-allowed' | 'view-not-found'; + message: string; +}; + +export const getCreateWorkbookSessionTool = (server: Server): Tool => { + const createWorkbookSessionTool = new Tool({ + server, + name: 'create-workbook-session', + description: + 'Creates a session for the specified workbook. If a view ID is provided, the session will be created for the specified view. If no view ID is provided, the session will be created for the default view of the workbook.', + paramsSchema, + annotations: { + title: 'Create Workbook Session', + readOnlyHint: true, + openWorldHint: false, + }, + callback: async ( + { workbookId, viewId }, + { requestId, authInfo, signal }, + ): Promise => { + const config = getConfig(); + + return await createWorkbookSessionTool.logAndExecute< + { sessionid: string; globalSessionHeader: string | null }, + CreateWorkbookSessionError + >({ + requestId, + authInfo, + args: { workbookId, viewId }, + callback: async () => { + const isWorkbookAllowedResult = await resourceAccessChecker.isWorkbookAllowed({ + workbookId, + restApiArgs: { config, requestId, server, signal }, + }); + + if (!isWorkbookAllowedResult.allowed) { + return new Err({ + type: 'workbook-not-allowed', + message: isWorkbookAllowedResult.message, + }); + } + + const result = await useRestApi({ + config, + requestId, + server, + jwtScopes: ['tableau:content:read'], + signal, + callback: async (restApi) => { + const workbook = await restApi.workbooksMethods.getWorkbook({ + siteId: restApi.siteId, + workbookId, + }); + + const workbookName = workbook.name; + const viewName = workbook.defaultViewId + ? workbook.views?.view.find((v) => v.id === workbook.defaultViewId)?.name + : undefined; + + if (!viewName) { + return new Err({ + type: 'view-not-found', + message: 'No view ID provided and no default view for workbook found.', + } as const); + } + + const vizqlClient = await getNewVizqlApiInstanceAsync({ + baseUrl: config.server || getTableauAuthInfo(authInfo)?.server || '', + requestId, + server, + maxRequestTimeoutMs: config.maxRequestTimeoutMs, + signal, + }); + + const { + result: { sessionid }, + headerValue: globalSessionHeader, + } = await useHeaderExtractorPlugin({ + client: vizqlClient, + headerName: 'global-session-header', + timeoutMs: config.maxRequestTimeoutMs, + signal, + clientCallback: async (client) => { + return await client.startSession(undefined, { + params: { + siteName: config.siteName, + workbookName, + viewName, + }, + headers: { + Cookie: `workgroup_session_id=${restApi.creds.token};`, + }, + }); + }, + }); + + return new Ok({ + sessionid, + globalSessionHeader, + }); + }, + }); + + return result; + }, + constrainSuccessResult: (result) => { + return { + type: 'success', + result, + }; + }, + getErrorText: (error: CreateWorkbookSessionError) => { + switch (error.type) { + case 'workbook-not-allowed': + return error.message; + case 'view-not-found': + return error.message; + } + }, + }); + }, + }); + + return createWorkbookSessionTool; +}; diff --git a/src/vizqlApiInstance.ts b/src/vizqlApiInstance.ts new file mode 100644 index 00000000..266d9440 --- /dev/null +++ b/src/vizqlApiInstance.ts @@ -0,0 +1,62 @@ +import { RequestId } from '@modelcontextprotocol/sdk/types.js'; + +import { + addInterceptors, + getRequestErrorInterceptor, + getRequestInterceptor, + getResponseErrorInterceptor, + getResponseInterceptor, +} from './apiClients.js'; +import { log } from './logging/log.js'; +import { getClient, VizqlClient } from './sdks/tableau-vizql/client.js'; +import { Server } from './server.js'; + +export const getNewVizqlApiInstanceAsync = async ({ + baseUrl, + requestId, + server, + maxRequestTimeoutMs, + signal, +}: { + baseUrl: string; + requestId: RequestId; + server: Server; + maxRequestTimeoutMs: number; + signal: AbortSignal; +}): Promise => { + signal.addEventListener( + 'abort', + () => { + log.info( + server, + { + type: 'request-cancelled', + requestId, + reason: signal.reason, + }, + { logger: server.name, requestId }, + ); + }, + { once: true }, + ); + + const client = getClient(baseUrl, { + timeout: maxRequestTimeoutMs, + signal, + }); + + addInterceptors( + baseUrl, + client.axios.interceptors, + [ + getRequestInterceptor(server, requestId, 'vizql-api'), + getRequestErrorInterceptor(server, requestId, 'vizql-api'), + ], + [ + getResponseInterceptor(server, requestId, 'vizql-api'), + getResponseErrorInterceptor(server, requestId, 'vizql-api'), + ], + ); + + return client; +};