diff --git a/.changeset/fix-5xx-error-classification.md b/.changeset/fix-5xx-error-classification.md new file mode 100644 index 00000000..61318df2 --- /dev/null +++ b/.changeset/fix-5xx-error-classification.md @@ -0,0 +1,5 @@ +--- +'@repo/mcp-common': patch +--- + +Classify upstream 4xx errors correctly instead of returning 500, and set reportToSentry flag to avoid alerting on expected client errors diff --git a/packages/mcp-common/src/__mocks__/cloudflare.ts b/packages/mcp-common/src/__mocks__/cloudflare.ts new file mode 100644 index 00000000..c3c5a505 --- /dev/null +++ b/packages/mcp-common/src/__mocks__/cloudflare.ts @@ -0,0 +1,18 @@ +/** + * Mock for the 'cloudflare' SDK package. + * The real SDK transitively imports ReadStream from node:fs which is unavailable in workerd. + * This mock is wired via resolve.alias in vitest.config.ts so the real module is never loaded. + */ + +export class Cloudflare { + constructor(_opts?: Record) {} +} + +export class APIError extends Error { + status: number + constructor(status: number, message?: string) { + super(message) + this.status = status + this.name = 'APIError' + } +} diff --git a/packages/mcp-common/src/cloudflare-api.spec.ts b/packages/mcp-common/src/cloudflare-api.spec.ts new file mode 100644 index 00000000..bcf26dc1 --- /dev/null +++ b/packages/mcp-common/src/cloudflare-api.spec.ts @@ -0,0 +1,159 @@ +import { fetchMock } from 'cloudflare:test' +import { beforeAll, describe, expect, it } from 'vitest' + +import { fetchCloudflareApi } from './cloudflare-api' +import { McpError } from './mcp-error' + +beforeAll(() => { + fetchMock.activate() + fetchMock.disableNetConnect() +}) + +describe('fetchCloudflareApi', () => { + const baseParams = { + endpoint: '/workers/scripts', + accountId: 'test-account-id', + apiToken: 'test-api-token', + } + + it('returns parsed data on success', async () => { + const responseData = { result: { id: 'test-script' } } + fetchMock + .get('https://api.cloudflare.com') + .intercept({ + path: '/client/v4/accounts/test-account-id/workers/scripts', + method: 'GET', + }) + .reply(200, responseData) + + const result = await fetchCloudflareApi(baseParams) + expect(result).toEqual(responseData) + }) + + it('throws McpError with status 404 for not found', async () => { + fetchMock + .get('https://api.cloudflare.com') + .intercept({ + path: '/client/v4/accounts/test-account-id/workers/scripts', + method: 'GET', + }) + .reply(404, JSON.stringify({ errors: [{ message: 'Script not found' }] })) + + try { + await fetchCloudflareApi(baseParams) + expect.unreachable() + } catch (e) { + expect(e).toBeInstanceOf(McpError) + const err = e as McpError + expect(err.code).toBe(404) + expect(err.reportToSentry).toBe(false) + expect(err.message).toBe('Cloudflare API request failed') + expect(err.internalMessage).toContain('Script not found') + } + }) + + it('throws McpError with status 403 for forbidden', async () => { + fetchMock + .get('https://api.cloudflare.com') + .intercept({ + path: '/client/v4/accounts/test-account-id/workers/scripts', + method: 'GET', + }) + .reply(403, JSON.stringify({ errors: [{ message: 'Forbidden' }] })) + + try { + await fetchCloudflareApi(baseParams) + expect.unreachable() + } catch (e) { + expect(e).toBeInstanceOf(McpError) + const err = e as McpError + expect(err.code).toBe(403) + expect(err.reportToSentry).toBe(false) + } + }) + + it('throws McpError with status 429 for rate limiting', async () => { + fetchMock + .get('https://api.cloudflare.com') + .intercept({ + path: '/client/v4/accounts/test-account-id/workers/scripts', + method: 'GET', + }) + .reply(429, JSON.stringify({ errors: [{ message: 'Rate limited' }] })) + + try { + await fetchCloudflareApi(baseParams) + expect.unreachable() + } catch (e) { + expect(e).toBeInstanceOf(McpError) + const err = e as McpError + expect(err.code).toBe(429) + expect(err.reportToSentry).toBe(false) + } + }) + + it('throws McpError with status 502 for upstream 500 (bad gateway)', async () => { + fetchMock + .get('https://api.cloudflare.com') + .intercept({ + path: '/client/v4/accounts/test-account-id/workers/scripts', + method: 'GET', + }) + .reply(500, 'Internal Server Error') + + try { + await fetchCloudflareApi(baseParams) + expect.unreachable() + } catch (e) { + expect(e).toBeInstanceOf(McpError) + const err = e as McpError + expect(err.code).toBe(502) + expect(err.message).toBe('Upstream Cloudflare API unavailable') + expect(err.reportToSentry).toBe(true) + expect(err.internalMessage).toContain('Cloudflare API 500') + } + }) + + it('throws McpError with status 502 for upstream 502', async () => { + fetchMock + .get('https://api.cloudflare.com') + .intercept({ + path: '/client/v4/accounts/test-account-id/workers/scripts', + method: 'GET', + }) + .reply(502, 'Bad Gateway') + + try { + await fetchCloudflareApi(baseParams) + expect.unreachable() + } catch (e) { + expect(e).toBeInstanceOf(McpError) + const err = e as McpError + expect(err.code).toBe(502) + expect(err.message).toBe('Upstream Cloudflare API unavailable') + expect(err.reportToSentry).toBe(true) + expect(err.internalMessage).toContain('Cloudflare API 502') + } + }) + + it('preserves error text in internalMessage (not user-facing message)', async () => { + const errorBody = '{"errors":[{"message":"Worker not found","code":10007}]}' + fetchMock + .get('https://api.cloudflare.com') + .intercept({ + path: '/client/v4/accounts/test-account-id/workers/scripts', + method: 'GET', + }) + .reply(404, errorBody) + + try { + await fetchCloudflareApi(baseParams) + expect.unreachable() + } catch (e) { + expect(e).toBeInstanceOf(McpError) + const err = e as McpError + expect(err.message).toBe('Cloudflare API request failed') + expect(err.internalMessage).toContain('Worker not found') + } + }) +}) diff --git a/packages/mcp-common/src/cloudflare-api.ts b/packages/mcp-common/src/cloudflare-api.ts index 447e411f..b53d834e 100644 --- a/packages/mcp-common/src/cloudflare-api.ts +++ b/packages/mcp-common/src/cloudflare-api.ts @@ -1,6 +1,8 @@ import { Cloudflare } from 'cloudflare' import { env } from 'cloudflare:workers' +import { throwUpstreamApiError } from './mcp-error' + import type { z } from 'zod' export function getCloudflareClient(apiToken: string) { @@ -55,8 +57,7 @@ export async function fetchCloudflareApi({ }) if (!response.ok) { - const error = await response.text() - throw new Error(`Cloudflare API request failed: ${error}`) + throwUpstreamApiError(response.status, 'Cloudflare API', await response.text()) } const data = await response.json() diff --git a/packages/mcp-common/src/cloudflare-auth.spec.ts b/packages/mcp-common/src/cloudflare-auth.spec.ts new file mode 100644 index 00000000..f7734175 --- /dev/null +++ b/packages/mcp-common/src/cloudflare-auth.spec.ts @@ -0,0 +1,329 @@ +import { fetchMock } from 'cloudflare:test' +import { beforeAll, describe, expect, it } from 'vitest' + +import { getAuthToken, refreshAuthToken } from './cloudflare-auth' +import { McpError } from './mcp-error' + +beforeAll(() => { + fetchMock.activate() + fetchMock.disableNetConnect() +}) + +const validTokenResponse = { + access_token: 'test-access-token', + expires_in: 3600, + refresh_token: 'test-refresh-token', + scope: 'read write', + token_type: 'bearer', +} + +describe('getAuthToken', () => { + const baseParams = { + client_id: 'test-client-id', + client_secret: 'test-client-secret', + redirect_uri: 'https://example.com/callback', + code_verifier: 'test-verifier', + code: 'test-code', + } + + it('throws McpError 400 for missing code', async () => { + try { + await getAuthToken({ ...baseParams, code: '' }) + expect.unreachable() + } catch (e) { + expect(e).toBeInstanceOf(McpError) + expect((e as McpError).code).toBe(400) + } + }) + + it('returns parsed token on success', async () => { + fetchMock + .get('https://dash.cloudflare.com') + .intercept({ path: '/oauth2/token', method: 'POST' }) + .reply(200, validTokenResponse) + + const result = await getAuthToken(baseParams) + expect(result.access_token).toBe('test-access-token') + expect(result.refresh_token).toBe('test-refresh-token') + expect(result.expires_in).toBe(3600) + }) + + it('throws McpError with upstream status for 400 (expired/invalid grant)', async () => { + fetchMock + .get('https://dash.cloudflare.com') + .intercept({ path: '/oauth2/token', method: 'POST' }) + .reply( + 400, + JSON.stringify({ + error: 'invalid_grant', + error_description: 'The authorization code has expired', + }) + ) + + try { + await getAuthToken(baseParams) + expect.unreachable() + } catch (e) { + expect(e).toBeInstanceOf(McpError) + const err = e as McpError + expect(err.code).toBe(400) + expect(err.message).toBe('Authorization grant is invalid, expired, or revoked') + expect(err.reportToSentry).toBe(false) + expect(err.internalMessage).toContain('Upstream 400') + expect(err.internalMessage).toContain('The authorization code has expired') + } + }) + + it('throws McpError with upstream status for 401 (bad client credentials)', async () => { + fetchMock + .get('https://dash.cloudflare.com') + .intercept({ path: '/oauth2/token', method: 'POST' }) + .reply( + 401, + JSON.stringify({ + error: 'invalid_client', + error_description: 'Invalid client credentials', + }) + ) + + try { + await getAuthToken(baseParams) + expect.unreachable() + } catch (e) { + expect(e).toBeInstanceOf(McpError) + const err = e as McpError + expect(err.code).toBe(401) + expect(err.message).toBe('Client authentication failed') + expect(err.reportToSentry).toBe(false) + } + }) + + it('throws McpError with upstream status for 403 (insufficient permissions)', async () => { + fetchMock + .get('https://dash.cloudflare.com') + .intercept({ path: '/oauth2/token', method: 'POST' }) + .reply( + 403, + JSON.stringify({ + error: 'access_denied', + error_description: 'Insufficient permissions', + }) + ) + + try { + await getAuthToken(baseParams) + expect.unreachable() + } catch (e) { + expect(e).toBeInstanceOf(McpError) + const err = e as McpError + expect(err.code).toBe(403) + expect(err.message).toBe('Access denied') + expect(err.reportToSentry).toBe(false) + } + }) + + it('throws McpError with upstream status for 429 (rate limited)', async () => { + fetchMock + .get('https://dash.cloudflare.com') + .intercept({ path: '/oauth2/token', method: 'POST' }) + .reply( + 429, + JSON.stringify({ + error: 'rate_limited', + error_description: 'Too many requests', + }) + ) + + try { + await getAuthToken(baseParams) + expect.unreachable() + } catch (e) { + expect(e).toBeInstanceOf(McpError) + const err = e as McpError + expect(err.code).toBe(429) + expect(err.reportToSentry).toBe(false) + } + }) + + it('throws McpError 502 for upstream 500 (server error)', async () => { + fetchMock + .get('https://dash.cloudflare.com') + .intercept({ path: '/oauth2/token', method: 'POST' }) + .reply(500, 'Internal Server Error') + + try { + await getAuthToken(baseParams) + expect.unreachable() + } catch (e) { + expect(e).toBeInstanceOf(McpError) + const err = e as McpError + expect(err.code).toBe(502) + expect(err.message).toBe('Upstream token service unavailable') + expect(err.reportToSentry).toBe(true) + expect(err.internalMessage).toContain('Upstream 500') + } + }) + + it('throws McpError 502 for upstream 503 (service unavailable)', async () => { + fetchMock + .get('https://dash.cloudflare.com') + .intercept({ path: '/oauth2/token', method: 'POST' }) + .reply(503, 'Service Unavailable') + + try { + await getAuthToken(baseParams) + expect.unreachable() + } catch (e) { + expect(e).toBeInstanceOf(McpError) + const err = e as McpError + expect(err.code).toBe(502) + expect(err.reportToSentry).toBe(true) + } + }) + + it('uses fallback message when upstream body is not JSON', async () => { + fetchMock + .get('https://dash.cloudflare.com') + .intercept({ path: '/oauth2/token', method: 'POST' }) + .reply(400, 'Bad Request - plain text') + + try { + await getAuthToken(baseParams) + expect.unreachable() + } catch (e) { + expect(e).toBeInstanceOf(McpError) + const err = e as McpError + expect(err.code).toBe(400) + expect(err.message).toBe('Token exchange failed') + expect(err.reportToSentry).toBe(false) + } + }) +}) + +describe('refreshAuthToken', () => { + const baseParams = { + client_id: 'test-client-id', + client_secret: 'test-client-secret', + refresh_token: 'test-refresh-token', + } + + it('returns parsed token on success', async () => { + fetchMock + .get('https://dash.cloudflare.com') + .intercept({ path: '/oauth2/token', method: 'POST' }) + .reply(200, validTokenResponse) + + const result = await refreshAuthToken(baseParams) + expect(result.access_token).toBe('test-access-token') + expect(result.refresh_token).toBe('test-refresh-token') + }) + + it('throws McpError with upstream status for 400 (expired refresh token)', async () => { + fetchMock + .get('https://dash.cloudflare.com') + .intercept({ path: '/oauth2/token', method: 'POST' }) + .reply( + 400, + JSON.stringify({ + error: 'invalid_grant', + error_description: 'The refresh token has expired', + }) + ) + + try { + await refreshAuthToken(baseParams) + expect.unreachable() + } catch (e) { + expect(e).toBeInstanceOf(McpError) + const err = e as McpError + expect(err.code).toBe(400) + expect(err.message).toBe('Authorization grant is invalid, expired, or revoked') + expect(err.reportToSentry).toBe(false) + expect(err.internalMessage).toContain('Upstream 400') + expect(err.internalMessage).toContain('The refresh token has expired') + } + }) + + it('throws McpError with upstream status for 401 (invalid client)', async () => { + fetchMock + .get('https://dash.cloudflare.com') + .intercept({ path: '/oauth2/token', method: 'POST' }) + .reply( + 401, + JSON.stringify({ + error: 'invalid_client', + error_description: 'Bad client credentials', + }) + ) + + try { + await refreshAuthToken(baseParams) + expect.unreachable() + } catch (e) { + expect(e).toBeInstanceOf(McpError) + const err = e as McpError + expect(err.code).toBe(401) + expect(err.reportToSentry).toBe(false) + } + }) + + it('throws McpError 502 for upstream 500', async () => { + fetchMock + .get('https://dash.cloudflare.com') + .intercept({ path: '/oauth2/token', method: 'POST' }) + .reply(500, 'Server Error') + + try { + await refreshAuthToken(baseParams) + expect.unreachable() + } catch (e) { + expect(e).toBeInstanceOf(McpError) + const err = e as McpError + expect(err.code).toBe(502) + expect(err.message).toBe('Upstream token service unavailable') + expect(err.reportToSentry).toBe(true) + } + }) + + it('uses fallback message when upstream error code is unknown', async () => { + fetchMock + .get('https://dash.cloudflare.com') + .intercept({ path: '/oauth2/token', method: 'POST' }) + .reply(400, JSON.stringify({ error: 'some_unknown_error' })) + + try { + await refreshAuthToken(baseParams) + expect.unreachable() + } catch (e) { + expect(e).toBeInstanceOf(McpError) + const err = e as McpError + expect(err.code).toBe(400) + expect(err.message).toBe('Token refresh failed') + expect(err.reportToSentry).toBe(false) + } + }) + + it('maps known error codes to safe messages instead of forwarding error_description', async () => { + fetchMock + .get('https://dash.cloudflare.com') + .intercept({ path: '/oauth2/token', method: 'POST' }) + .reply( + 400, + JSON.stringify({ + error: 'invalid_grant', + error_description: 'Internal: token xyz expired at 2024-01-01', + }) + ) + + try { + await refreshAuthToken(baseParams) + expect.unreachable() + } catch (e) { + expect(e).toBeInstanceOf(McpError) + const err = e as McpError + expect(err.message).toBe('Authorization grant is invalid, expired, or revoked') + // Raw upstream detail preserved in internalMessage only + expect(err.internalMessage).toContain('Internal: token xyz expired') + } + }) +}) diff --git a/packages/mcp-common/src/cloudflare-auth.ts b/packages/mcp-common/src/cloudflare-auth.ts index 37d27c82..240829a0 100644 --- a/packages/mcp-common/src/cloudflare-auth.ts +++ b/packages/mcp-common/src/cloudflare-auth.ts @@ -1,9 +1,53 @@ import { z } from 'zod' -import { McpError } from './mcp-error' +import { McpError, safeStatusCode } from './mcp-error' import type { AuthRequest } from '@cloudflare/workers-oauth-provider' +/** Maps known OAuth error codes to safe client-facing messages */ +const SAFE_TOKEN_ERROR_MESSAGES: Record = { + invalid_grant: 'Authorization grant is invalid, expired, or revoked', + invalid_client: 'Client authentication failed', + invalid_request: 'Invalid token request', + unauthorized_client: 'Client is not authorized for this grant type', + unsupported_grant_type: 'Unsupported grant type', + invalid_scope: 'Requested scope is invalid', + access_denied: 'Access denied', +} + +/** + * Throw an McpError for an upstream token endpoint failure. + * 4xx: preserves status with a safe message mapped from the OAuth error code. + * 5xx: maps to 502 Bad Gateway. + */ +function throwUpstreamTokenError(status: number, body: string, context: string): never { + let upstreamError: { error?: string } = {} + try { + upstreamError = JSON.parse(body) + } catch { + // upstream may return non-JSON error bodies + } + + // Truncate body to avoid capturing excessive data in logs/Sentry + const truncatedBody = body.length > 500 ? body.slice(0, 500) + '...' : body + + if (status >= 400 && status < 500) { + throw new McpError( + SAFE_TOKEN_ERROR_MESSAGES[upstreamError.error || ''] || context, + safeStatusCode(status), + { + reportToSentry: false, + internalMessage: `Upstream ${status}: ${truncatedBody}`, + } + ) + } + + throw new McpError('Upstream token service unavailable', 502, { + reportToSentry: true, + internalMessage: `Upstream ${status}: ${truncatedBody}`, + }) +} + // Constants const PKCE_CHARSET = 'ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789-._~' const RECOMMENDED_CODE_VERIFIER_LENGTH = 96 @@ -152,8 +196,7 @@ export async function getAuthToken({ }) if (!resp.ok) { - console.log(await resp.text()) - throw new McpError('Failed to get OAuth token', 500, { reportToSentry: true }) + throwUpstreamTokenError(resp.status, await resp.text(), 'Token exchange failed') } return AuthorizationToken.parse(await resp.json()) @@ -183,8 +226,7 @@ export async function refreshAuthToken({ }, }) if (!resp.ok) { - console.log(await resp.text()) - throw new McpError('Failed to get OAuth token', 500, { reportToSentry: true }) + throwUpstreamTokenError(resp.status, await resp.text(), 'Token refresh failed') } return AuthorizationToken.parse(await resp.json()) diff --git a/packages/mcp-common/src/cloudflare-oauth-handler.spec.ts b/packages/mcp-common/src/cloudflare-oauth-handler.spec.ts new file mode 100644 index 00000000..3e3386d9 --- /dev/null +++ b/packages/mcp-common/src/cloudflare-oauth-handler.spec.ts @@ -0,0 +1,486 @@ +import { GrantType } from '@cloudflare/workers-oauth-provider' +import { fetchMock } from 'cloudflare:test' +import { afterEach, beforeAll, beforeEach, describe, expect, it, vi } from 'vitest' + +import { refreshAuthToken } from './cloudflare-auth' +import { getUserAndAccounts, handleTokenExchangeCallback } from './cloudflare-oauth-handler' +import { McpError } from './mcp-error' +import { OAuthError } from './workers-oauth-utils' + +import type { TokenExchangeCallbackOptions } from '@cloudflare/workers-oauth-provider' + +// Mock the refreshAuthToken function +vi.mock('./cloudflare-auth', () => ({ + refreshAuthToken: vi.fn(), + getAuthToken: vi.fn(), + generatePKCECodes: vi.fn(), + getAuthorizationURL: vi.fn(), +})) + +const mockRefreshAuthToken = vi.mocked(refreshAuthToken) + +beforeAll(() => { + fetchMock.activate() + fetchMock.disableNetConnect() +}) + +beforeEach(() => { + vi.resetAllMocks() +}) + +afterEach(() => { + vi.restoreAllMocks() +}) + +function makeRefreshOptions(propsOverride: Record): TokenExchangeCallbackOptions { + return { + grantType: GrantType.REFRESH_TOKEN, + props: propsOverride, + clientId: 'test', + userId: 'test-user', + scope: [], + requestedScope: [], + } +} + +describe('handleTokenExchangeCallback', () => { + const clientId = 'test-client-id' + const clientSecret = 'test-client-secret' + + describe('account_token refresh attempt', () => { + it('throws OAuthError invalid_grant for account token refresh', async () => { + const options = makeRefreshOptions({ + type: 'account_token', + accessToken: 'test-token', + account: { name: 'test', id: 'test-id' }, + }) + + try { + await handleTokenExchangeCallback(options, clientId, clientSecret) + expect.unreachable() + } catch (e) { + expect(e).toBeInstanceOf(OAuthError) + const err = e as OAuthError + expect(err.code).toBe('invalid_grant') + expect(err.statusCode).toBe(400) + expect(err.description).toBe('Account tokens cannot be refreshed') + } + }) + }) + + describe('missing refresh token', () => { + it('throws OAuthError invalid_grant when refreshToken is missing', async () => { + const options = makeRefreshOptions({ + type: 'user_token', + accessToken: 'test-token', + user: { id: 'user-1', email: 'user@example.com' }, + accounts: [{ name: 'test', id: 'test-id' }], + // no refreshToken + }) + + try { + await handleTokenExchangeCallback(options, clientId, clientSecret) + expect.unreachable() + } catch (e) { + expect(e).toBeInstanceOf(OAuthError) + const err = e as OAuthError + expect(err.code).toBe('invalid_grant') + expect(err.statusCode).toBe(400) + expect(err.description).toBe('No refresh token available for this grant') + } + }) + }) + + describe('successful refresh', () => { + it('returns new props and TTL on successful refresh', async () => { + mockRefreshAuthToken.mockResolvedValueOnce({ + access_token: 'new-access-token', + refresh_token: 'new-refresh-token', + expires_in: 7200, + scope: 'read write', + token_type: 'bearer', + }) + + const options = makeRefreshOptions({ + type: 'user_token', + accessToken: 'old-access-token', + refreshToken: 'old-refresh-token', + user: { id: 'user-1', email: 'user@example.com' }, + accounts: [{ name: 'test', id: 'test-id' }], + }) + + const result = await handleTokenExchangeCallback(options, clientId, clientSecret) + expect(result).toBeDefined() + expect(result!.accessTokenTTL).toBe(7200) + expect(result!.newProps).toMatchObject({ + accessToken: 'new-access-token', + refreshToken: 'new-refresh-token', + }) + }) + }) + + describe('converts upstream McpErrors from refreshAuthToken to OAuthError', () => { + it('converts McpError 400 from expired upstream refresh token to OAuthError invalid_grant', async () => { + mockRefreshAuthToken.mockRejectedValueOnce( + new McpError('Authorization grant is invalid, expired, or revoked', 400, { + reportToSentry: false, + internalMessage: 'Upstream 400: {"error":"invalid_grant"}', + }) + ) + + const options = makeRefreshOptions({ + type: 'user_token', + accessToken: 'test-token', + refreshToken: 'expired-refresh-token', + user: { id: 'user-1', email: 'user@example.com' }, + accounts: [{ name: 'test', id: 'test-id' }], + }) + + try { + await handleTokenExchangeCallback(options, clientId, clientSecret) + expect.unreachable() + } catch (e) { + expect(e).toBeInstanceOf(OAuthError) + const err = e as OAuthError + expect(err.code).toBe('invalid_grant') + expect(err.statusCode).toBe(400) + expect(err.description).toBe('Authorization grant is invalid, expired, or revoked') + } + }) + + it('converts McpError 502 from upstream server error to OAuthError server_error', async () => { + mockRefreshAuthToken.mockRejectedValueOnce( + new McpError('Upstream token service unavailable', 502, { + reportToSentry: true, + internalMessage: 'Upstream 500: Internal Server Error', + }) + ) + + const options = makeRefreshOptions({ + type: 'user_token', + accessToken: 'test-token', + refreshToken: 'valid-refresh-token', + user: { id: 'user-1', email: 'user@example.com' }, + accounts: [{ name: 'test', id: 'test-id' }], + }) + + try { + await handleTokenExchangeCallback(options, clientId, clientSecret) + expect.unreachable() + } catch (e) { + expect(e).toBeInstanceOf(OAuthError) + const err = e as OAuthError + expect(err.code).toBe('server_error') + expect(err.statusCode).toBe(500) + expect(err.description).toBe('Upstream token service unavailable') + } + }) + + it('converts McpError 429 to OAuthError temporarily_unavailable', async () => { + mockRefreshAuthToken.mockRejectedValueOnce( + new McpError('Rate limited, try again later', 429, { + reportToSentry: false, + internalMessage: 'Upstream 429', + }) + ) + + const options = makeRefreshOptions({ + type: 'user_token', + accessToken: 'test-token', + refreshToken: 'valid-refresh-token', + user: { id: 'user-1', email: 'user@example.com' }, + accounts: [{ name: 'test', id: 'test-id' }], + }) + + try { + await handleTokenExchangeCallback(options, clientId, clientSecret) + expect.unreachable() + } catch (e) { + expect(e).toBeInstanceOf(OAuthError) + const err = e as OAuthError + expect(err.code).toBe('temporarily_unavailable') + expect(err.statusCode).toBe(503) + } + }) + + it('converts McpError 401 to OAuthError invalid_client', async () => { + mockRefreshAuthToken.mockRejectedValueOnce( + new McpError('Access token is invalid or expired', 401, { + reportToSentry: false, + internalMessage: 'Upstream 401', + }) + ) + + const options = makeRefreshOptions({ + type: 'user_token', + accessToken: 'test-token', + refreshToken: 'valid-refresh-token', + user: { id: 'user-1', email: 'user@example.com' }, + accounts: [{ name: 'test', id: 'test-id' }], + }) + + try { + await handleTokenExchangeCallback(options, clientId, clientSecret) + expect.unreachable() + } catch (e) { + expect(e).toBeInstanceOf(OAuthError) + const err = e as OAuthError + expect(err.code).toBe('invalid_client') + expect(err.statusCode).toBe(401) + } + }) + + it('re-throws non-McpError errors unchanged', async () => { + const genericError = new Error('unexpected failure') + mockRefreshAuthToken.mockRejectedValueOnce(genericError) + + const options = makeRefreshOptions({ + type: 'user_token', + accessToken: 'test-token', + refreshToken: 'valid-refresh-token', + user: { id: 'user-1', email: 'user@example.com' }, + accounts: [{ name: 'test', id: 'test-id' }], + }) + + try { + await handleTokenExchangeCallback(options, clientId, clientSecret) + expect.unreachable() + } catch (e) { + expect(e).toBe(genericError) + expect(e).not.toBeInstanceOf(OAuthError) + } + }) + }) + + describe('non-refresh grant types', () => { + it('returns undefined for authorization_code grant type', async () => { + const options: TokenExchangeCallbackOptions = { + grantType: GrantType.AUTHORIZATION_CODE, + props: {}, + clientId: 'test', + userId: 'test-user', + scope: [], + requestedScope: [], + } + + const result = await handleTokenExchangeCallback(options, clientId, clientSecret) + expect(result).toBeUndefined() + }) + }) +}) + +function mockUserResponse(status: number, body?: unknown) { + fetchMock + .get('https://api.cloudflare.com') + .intercept({ path: '/client/v4/user', method: 'GET' }) + .reply(status, body ? JSON.stringify(body) : '') +} + +function mockAccountsResponse(status: number, body?: unknown) { + fetchMock + .get('https://api.cloudflare.com') + .intercept({ path: '/client/v4/accounts', method: 'GET' }) + .reply(status, body ? JSON.stringify(body) : '') +} + +const v4User = { + success: true, + result: { id: 'user-1', email: 'user@example.com' }, + errors: [], + messages: [], +} +const v4Accounts = { + success: true, + result: [{ id: 'acc-1', name: 'My Account' }], + errors: [], + messages: [], +} + +describe('getUserAndAccounts', () => { + it('returns user and accounts on success', async () => { + mockUserResponse(200, v4User) + mockAccountsResponse(200, v4Accounts) + + const result = await getUserAndAccounts('test-token') + expect(result.user).toEqual({ id: 'user-1', email: 'user@example.com' }) + expect(result.accounts).toEqual([{ id: 'acc-1', name: 'My Account' }]) + }) + + it('returns user=null for account-scoped tokens (user 401, accounts 200)', async () => { + mockUserResponse(401, { errors: [{ message: 'Unauthorized' }] }) + mockAccountsResponse(200, v4Accounts) + + const result = await getUserAndAccounts('test-token') + expect(result.user).toBeNull() + expect(result.accounts).toEqual([{ id: 'acc-1', name: 'My Account' }]) + }) + + describe('combined failure (both endpoints fail)', () => { + it('throws 502 when any endpoint returns 5xx', async () => { + mockUserResponse(401) + mockAccountsResponse(500) + + try { + await getUserAndAccounts('test-token') + expect.unreachable() + } catch (e) { + expect(e).toBeInstanceOf(McpError) + const err = e as McpError + expect(err.code).toBe(502) + expect(err.reportToSentry).toBe(true) + } + }) + + it('throws 429 when rate limited', async () => { + mockUserResponse(429) + mockAccountsResponse(429) + + try { + await getUserAndAccounts('test-token') + expect.unreachable() + } catch (e) { + expect(e).toBeInstanceOf(McpError) + const err = e as McpError + expect(err.code).toBe(429) + expect(err.reportToSentry).toBe(false) + } + }) + + it('throws 401 for expired token', async () => { + mockUserResponse(401) + mockAccountsResponse(401) + + try { + await getUserAndAccounts('test-token') + expect.unreachable() + } catch (e) { + expect(e).toBeInstanceOf(McpError) + const err = e as McpError + expect(err.code).toBe(401) + expect(err.reportToSentry).toBe(false) + } + }) + + it('throws 403 for insufficient permissions', async () => { + mockUserResponse(403) + mockAccountsResponse(403) + + try { + await getUserAndAccounts('test-token') + expect.unreachable() + } catch (e) { + expect(e).toBeInstanceOf(McpError) + const err = e as McpError + expect(err.code).toBe(403) + expect(err.reportToSentry).toBe(false) + } + }) + }) + + it('throws 401 when no user or account information is returned', async () => { + mockUserResponse(200, { success: true, result: null, errors: [], messages: [] }) + mockAccountsResponse(200, { success: true, result: [], errors: [], messages: [] }) + + try { + await getUserAndAccounts('test-token') + expect.unreachable() + } catch (e) { + expect(e).toBeInstanceOf(McpError) + const err = e as McpError + expect(err.code).toBe(401) + expect(err.message).toBe('Failed to verify token: no user or account information') + } + }) + + it('gracefully handles malformed JSON in /user response', async () => { + fetchMock + .get('https://api.cloudflare.com') + .intercept({ path: '/client/v4/user', method: 'GET' }) + .reply(200, 'not json') + mockAccountsResponse(200, v4Accounts) + + // Should still return accounts even if user parsing fails + const result = await getUserAndAccounts('test-token') + expect(result.user).toBeNull() + expect(result.accounts).toEqual([{ id: 'acc-1', name: 'My Account' }]) + }) + + describe('mixed-status priority in combined failures', () => { + it('prioritizes 5xx over 429 (401+500 → 502)', async () => { + mockUserResponse(401) + mockAccountsResponse(500) + + try { + await getUserAndAccounts('test-token') + expect.unreachable() + } catch (e) { + expect(e).toBeInstanceOf(McpError) + const err = e as McpError + expect(err.code).toBe(502) + expect(err.reportToSentry).toBe(true) + } + }) + + it('prioritizes 429 over 401 (401+429 → 429)', async () => { + mockUserResponse(401) + mockAccountsResponse(429) + + try { + await getUserAndAccounts('test-token') + expect.unreachable() + } catch (e) { + expect(e).toBeInstanceOf(McpError) + const err = e as McpError + expect(err.code).toBe(429) + expect(err.reportToSentry).toBe(false) + } + }) + + it('prioritizes 5xx over 403 (403+500 → 502)', async () => { + mockUserResponse(403) + mockAccountsResponse(500) + + try { + await getUserAndAccounts('test-token') + expect.unreachable() + } catch (e) { + expect(e).toBeInstanceOf(McpError) + const err = e as McpError + expect(err.code).toBe(502) + expect(err.reportToSentry).toBe(true) + } + }) + }) + + describe('accounts failure with user success', () => { + it('throws when accounts returns 500 even if user succeeds', async () => { + mockUserResponse(200, v4User) + mockAccountsResponse(500) + + try { + await getUserAndAccounts('test-token') + expect.unreachable() + } catch (e) { + expect(e).toBeInstanceOf(McpError) + const err = e as McpError + expect(err.code).toBe(502) + expect(err.reportToSentry).toBe(true) + } + }) + + it('throws when accounts returns 403 even if user succeeds', async () => { + mockUserResponse(200, v4User) + mockAccountsResponse(403) + + try { + await getUserAndAccounts('test-token') + expect.unreachable() + } catch (e) { + expect(e).toBeInstanceOf(McpError) + const err = e as McpError + expect(err.code).toBe(403) + expect(err.reportToSentry).toBe(false) + } + }) + }) +}) diff --git a/packages/mcp-common/src/cloudflare-oauth-handler.ts b/packages/mcp-common/src/cloudflare-oauth-handler.ts index 22e823ab..d5241615 100644 --- a/packages/mcp-common/src/cloudflare-oauth-handler.ts +++ b/packages/mcp-common/src/cloudflare-oauth-handler.ts @@ -1,3 +1,4 @@ +import { GrantType } from '@cloudflare/workers-oauth-provider' import { zValidator } from '@hono/zod-validator' import { Hono } from 'hono' import { z } from 'zod' @@ -9,7 +10,7 @@ import { getAuthToken, refreshAuthToken, } from './cloudflare-auth' -import { McpError } from './mcp-error' +import { McpError, safeStatusCode, throwUpstreamApiError } from './mcp-error' import { useSentry } from './sentry' import { V4Schema } from './v4-api' import { @@ -33,6 +34,26 @@ import type { Context } from 'hono' import type { MetricsTracker } from '../../mcp-observability/src' import type { BaseHonoContext } from './sentry' +/** + * Converts an McpError into an OAuth 2.1 spec-compliant JSON error response. + * + * Maps HTTP status codes to the standard OAuth error codes defined in + * https://datatracker.ietf.org/doc/html/draft-ietf-oauth-v2-1-13#section-3.2.4 + */ +function mcpErrorToOAuthResponse(e: McpError): Response { + let oauthCode: string + if (e.code >= 500) { + oauthCode = 'server_error' + } else if (e.code === 429) { + oauthCode = 'temporarily_unavailable' + } else if (e.code === 401 || e.code === 403) { + oauthCode = 'access_denied' + } else { + oauthCode = 'invalid_request' + } + return new OAuthError(oauthCode, e.message, e.code >= 500 ? 500 : e.code).toResponse() +} + type AuthContext = { Bindings: { OAUTH_PROVIDER: OAuthHelpers @@ -78,6 +99,47 @@ const UserAuthProps = z.object({ export type AuthProps = z.infer const AuthProps = z.discriminatedUnion('type', [AccountAuthProps, UserAuthProps]) +/** + * Throws an McpError for combined /user + /accounts failures. + * Uses priority-based classification matching cloudflare-mcp patterns. + */ +function throwCombinedApiError(userStatus: number, accountsStatus: number): never { + const statuses = [userStatus, accountsStatus] + + if (statuses.some((s) => s >= 500)) { + throw new McpError('Cloudflare API is temporarily unavailable', 502, { + reportToSentry: true, + internalMessage: `Upstream user=${userStatus}, accounts=${accountsStatus}`, + }) + } + + if (statuses.includes(429)) { + throw new McpError('Rate limited, try again later', 429, { + reportToSentry: false, + internalMessage: `Upstream user=${userStatus}, accounts=${accountsStatus}`, + }) + } + + if (statuses.includes(401)) { + throw new McpError('Access token is invalid or expired', 401, { + reportToSentry: false, + internalMessage: `Upstream user=${userStatus}, accounts=${accountsStatus}`, + }) + } + + if (statuses.includes(403)) { + throw new McpError('Insufficient permissions', 403, { + reportToSentry: false, + internalMessage: `Upstream user=${userStatus}, accounts=${accountsStatus}`, + }) + } + + throw new McpError('Failed to verify token', safeStatusCode(userStatus), { + reportToSentry: false, + internalMessage: `Upstream user=${userStatus}, accounts=${accountsStatus}`, + }) +} + export async function getUserAndAccounts( accessToken: string, devModeHeaders?: HeadersInit @@ -88,32 +150,85 @@ export async function getUserAndAccounts( Authorization: `Bearer ${accessToken}`, } - // Fetch the user & accounts info from Cloudflare - const [userResponse, accountsResponse] = await Promise.all([ - fetch('https://api.cloudflare.com/client/v4/user', { - headers, - }), - fetch('https://api.cloudflare.com/client/v4/accounts', { - headers, - }), - ]) - - const { result: user } = V4Schema(UserSchema).parse(await userResponse.json()) - const { result: accounts } = V4Schema(AccountsSchema).parse(await accountsResponse.json()) - if (!user || !userResponse.ok) { - // If accounts is present, then assume that we have an account scoped token - if (accounts !== null) { - return { user: null, accounts } + // Fetch the user & accounts info from Cloudflare in parallel + let userResponse: Response + let accountsResponse: Response + try { + ;[userResponse, accountsResponse] = await Promise.all([ + fetch('https://api.cloudflare.com/client/v4/user', { headers }), + fetch('https://api.cloudflare.com/client/v4/accounts', { headers }), + ]) + } catch (error) { + console.error('Cloudflare API request failed', error) + throw new McpError('Cloudflare API is temporarily unavailable', 502, { + reportToSentry: true, + internalMessage: `Network error: ${error instanceof Error ? error.message : String(error)}`, + }) + } + + // If both endpoints failed, use priority-based error classification + if (!userResponse.ok && !accountsResponse.ok) { + console.error( + `Cloudflare API error: user=${userResponse.status}, accounts=${accountsResponse.status}` + ) + throwCombinedApiError(userResponse.status, accountsResponse.status) + } + + // Parse accounts with safeParse for graceful degradation + let accounts: AccountsSchema = [] + if (accountsResponse.ok) { + try { + const json = await accountsResponse.json() + const parsed = V4Schema(AccountsSchema).safeParse(json) + if (parsed.success) { + accounts = parsed.data.result ?? [] + } else { + console.error('Cloudflare API /accounts payload did not match expected shape', parsed.error) + } + } catch (error) { + console.error('Cloudflare API /accounts response is not valid JSON', error) + } + } else if (userResponse.ok) { + // User succeeded but accounts failed — surface the accounts error + // (5xx should be reported, 4xx like 403 may indicate insufficient scopes) + console.error(`Cloudflare API /accounts failed with status ${accountsResponse.status}`) + throwUpstreamApiError(accountsResponse.status, 'Cloudflare API /accounts') + } + + // Parse user with safeParse for graceful degradation + let user: UserSchema | null = null + if (userResponse.ok) { + try { + const json = await userResponse.json() + const parsed = V4Schema(UserSchema).safeParse(json) + if (parsed.success) { + user = parsed.data.result ?? null + } else { + console.error('Cloudflare API /user payload did not match expected shape', parsed.error) + } + } catch (error) { + console.error('Cloudflare API /user response is not valid JSON', error) } - console.log(user) - throw new McpError('Failed to fetch user', 500, { reportToSentry: true }) + } else if (accounts.length > 0) { + // User endpoint failed but accounts succeeded — account-scoped token + return { user: null, accounts } + } else { + throwUpstreamApiError(userResponse.status, 'Cloudflare API /user') } - if (!accounts || !accountsResponse.ok) { - console.log(accounts) - throw new McpError('Failed to fetch accounts', 500, { reportToSentry: true }) + + if (user) { + return { user, accounts } + } + + // Account-scoped token — user is null but accounts are present + if (accounts.length > 0) { + return { user: null, accounts } } - return { user, accounts } + throw new McpError('Failed to verify token: no user or account information', 401, { + reportToSentry: false, + internalMessage: `user=${userResponse.status}, accounts=${accountsResponse.status}`, + }) } /** @@ -158,26 +273,51 @@ export async function handleTokenExchangeCallback( clientSecret: string ): Promise { // options.props contains the current props - if (options.grantType === 'refresh_token') { + if (options.grantType === GrantType.REFRESH_TOKEN) { const props = AuthProps.parse(options.props) if (props.type === 'account_token') { - // Refreshing an account_token should not be possible, as we only do this for user tokens - throw new McpError('Internal Server Error', 500) + // Account tokens cannot be refreshed — this is a client error, not a server error + throw new OAuthError('invalid_grant', 'Account tokens cannot be refreshed', 400) } if (!props.refreshToken) { - throw new McpError('Missing refreshToken', 500) + throw new OAuthError('invalid_grant', 'No refresh token available for this grant', 400) } - // handle token refreshes - const { - access_token: accessToken, - refresh_token: refreshToken, - expires_in, - } = await refreshAuthToken({ - client_id: clientId, - client_secret: clientSecret, - refresh_token: props.refreshToken, - }) + // handle token refreshes — convert upstream McpErrors to OAuth-compliant errors + let accessToken: string + let refreshToken: string + let expires_in: number + try { + const result = await refreshAuthToken({ + client_id: clientId, + client_secret: clientSecret, + refresh_token: props.refreshToken, + }) + accessToken = result.access_token + refreshToken = result.refresh_token + expires_in = result.expires_in + } catch (e) { + if (e instanceof McpError) { + // Map upstream failures to OAuth error codes per RFC 6749 + let oauthCode: string + let httpStatus: number + if (e.code >= 500) { + oauthCode = 'server_error' + httpStatus = 500 + } else if (e.code === 429) { + oauthCode = 'temporarily_unavailable' + httpStatus = 503 + } else if (e.code === 401) { + oauthCode = 'invalid_client' + httpStatus = 401 + } else { + oauthCode = 'invalid_grant' + httpStatus = 400 + } + throw new OAuthError(oauthCode, e.message, httpStatus) + } + throw e + } return { newProps: { @@ -315,10 +455,10 @@ export function createAuthHandlers({ return e.toResponse() } if (e instanceof McpError) { - return c.text(e.message, { status: e.code }) + return mcpErrorToOAuthResponse(e) } console.error(e) - return c.text('Internal Error', 500) + return new OAuthError('server_error', 'Internal Error', 500).toResponse() } }) @@ -383,8 +523,11 @@ export function createAuthHandlers({ if (e instanceof OAuthError) { return e.toResponse() } + if (e instanceof McpError) { + return mcpErrorToOAuthResponse(e) + } console.error(e) - return c.text('Internal Error', 500) + return new OAuthError('server_error', 'Internal Error', 500).toResponse() } }) @@ -465,9 +608,9 @@ export function createAuthHandlers({ return e.toResponse() } if (e instanceof McpError) { - return c.text(e.message, { status: e.code }) + return mcpErrorToOAuthResponse(e) } - return c.text('Internal Error', 500) + return new OAuthError('server_error', 'Internal Error', 500).toResponse() } }) diff --git a/packages/mcp-common/src/mcp-error.ts b/packages/mcp-common/src/mcp-error.ts index 1bef1fa9..c1e0a562 100644 --- a/packages/mcp-common/src/mcp-error.ts +++ b/packages/mcp-common/src/mcp-error.ts @@ -1,4 +1,39 @@ -import type { ContentfulStatusCode } from 'hono/utils/http-status' +import type { ClientErrorStatusCode, ContentfulStatusCode } from 'hono/utils/http-status' + +const KNOWN_CLIENT_ERROR_CODES = new Set([ + 400, 401, 402, 403, 404, 405, 406, 407, 408, 409, 410, 411, 412, 413, 414, 415, 416, 417, 418, + 421, 422, 423, 424, 425, 426, 428, 429, 431, 451, +]) + +/** + * Safely maps an HTTP status code to a ContentfulStatusCode. + * Unknown 4xx codes fall back to 400; 5xx codes map to 502. + */ +export function safeStatusCode(status: number): ContentfulStatusCode { + if (KNOWN_CLIENT_ERROR_CODES.has(status)) return status as ClientErrorStatusCode + if (status >= 400 && status < 500) return 400 + if (status >= 500) return 502 + return 500 +} + +/** + * Throws an McpError for an upstream API failure. + * 4xx: preserves the status code with reportToSentry=false. + * 5xx: maps to 502 Bad Gateway with reportToSentry=true. + */ +export function throwUpstreamApiError(status: number, context: string, rawBody?: string): never { + const is5xx = status >= 500 + throw new McpError( + is5xx ? `Upstream ${context} unavailable` : `${context} request failed`, + safeStatusCode(is5xx ? 502 : status), + { + reportToSentry: is5xx, + internalMessage: rawBody + ? `${context} ${status}: ${rawBody.slice(0, 500)}` + : `${context} returned ${status}`, + } + ) +} export class McpError extends Error { public code: ContentfulStatusCode diff --git a/packages/mcp-common/src/sentry.spec.ts b/packages/mcp-common/src/sentry.spec.ts new file mode 100644 index 00000000..8c0535c9 --- /dev/null +++ b/packages/mcp-common/src/sentry.spec.ts @@ -0,0 +1,140 @@ +import { fetchMock } from 'cloudflare:test' +import { beforeAll, describe, expect, it } from 'vitest' + +import { fetchCloudflareApi } from './cloudflare-api' +import { getAuthToken, refreshAuthToken } from './cloudflare-auth' +import { McpError } from './mcp-error' + +beforeAll(() => { + fetchMock.activate() + fetchMock.disableNetConnect() +}) + +/** + * Tests that the actual production code sets reportToSentry correctly: + * - 4xx upstream errors should have reportToSentry=false (expected client errors) + * - 5xx upstream errors (mapped to 502) should have reportToSentry=true (unexpected) + */ +describe('reportToSentry flag in production code paths', () => { + describe('fetchCloudflareApi', () => { + const baseParams = { + endpoint: '/workers/scripts', + accountId: 'test-account-id', + apiToken: 'test-api-token', + } + + it('sets reportToSentry=false for 4xx errors', async () => { + fetchMock + .get('https://api.cloudflare.com') + .intercept({ + path: '/client/v4/accounts/test-account-id/workers/scripts', + method: 'GET', + }) + .reply(404, JSON.stringify({ errors: [{ message: 'Not found' }] })) + + try { + await fetchCloudflareApi(baseParams) + expect.unreachable() + } catch (e) { + expect(e).toBeInstanceOf(McpError) + expect((e as McpError).reportToSentry).toBe(false) + } + }) + + it('sets reportToSentry=true for 5xx errors', async () => { + fetchMock + .get('https://api.cloudflare.com') + .intercept({ + path: '/client/v4/accounts/test-account-id/workers/scripts', + method: 'GET', + }) + .reply(500, 'Internal Server Error') + + try { + await fetchCloudflareApi(baseParams) + expect.unreachable() + } catch (e) { + expect(e).toBeInstanceOf(McpError) + expect((e as McpError).reportToSentry).toBe(true) + } + }) + }) + + describe('getAuthToken', () => { + const baseParams = { + client_id: 'test-client-id', + client_secret: 'test-client-secret', + redirect_uri: 'https://example.com/callback', + code_verifier: 'test-verifier', + code: 'test-code', + } + + it('sets reportToSentry=false for 400 (invalid_grant)', async () => { + fetchMock + .get('https://dash.cloudflare.com') + .intercept({ path: '/oauth2/token', method: 'POST' }) + .reply(400, JSON.stringify({ error: 'invalid_grant' })) + + try { + await getAuthToken(baseParams) + expect.unreachable() + } catch (e) { + expect(e).toBeInstanceOf(McpError) + expect((e as McpError).reportToSentry).toBe(false) + } + }) + + it('sets reportToSentry=true for 502 (upstream 500)', async () => { + fetchMock + .get('https://dash.cloudflare.com') + .intercept({ path: '/oauth2/token', method: 'POST' }) + .reply(500, 'Internal Server Error') + + try { + await getAuthToken(baseParams) + expect.unreachable() + } catch (e) { + expect(e).toBeInstanceOf(McpError) + expect((e as McpError).reportToSentry).toBe(true) + } + }) + }) + + describe('refreshAuthToken', () => { + const baseParams = { + client_id: 'test-client-id', + client_secret: 'test-client-secret', + refresh_token: 'test-refresh-token', + } + + it('sets reportToSentry=false for 400 (expired refresh token)', async () => { + fetchMock + .get('https://dash.cloudflare.com') + .intercept({ path: '/oauth2/token', method: 'POST' }) + .reply(400, JSON.stringify({ error: 'invalid_grant' })) + + try { + await refreshAuthToken(baseParams) + expect.unreachable() + } catch (e) { + expect(e).toBeInstanceOf(McpError) + expect((e as McpError).reportToSentry).toBe(false) + } + }) + + it('sets reportToSentry=true for 502 (upstream 500)', async () => { + fetchMock + .get('https://dash.cloudflare.com') + .intercept({ path: '/oauth2/token', method: 'POST' }) + .reply(500, 'Server Error') + + try { + await refreshAuthToken(baseParams) + expect.unreachable() + } catch (e) { + expect(e).toBeInstanceOf(McpError) + expect((e as McpError).reportToSentry).toBe(true) + } + }) + }) +}) diff --git a/packages/mcp-common/src/workers-oauth-utils.spec.ts b/packages/mcp-common/src/workers-oauth-utils.spec.ts new file mode 100644 index 00000000..012685e0 --- /dev/null +++ b/packages/mcp-common/src/workers-oauth-utils.spec.ts @@ -0,0 +1,351 @@ +import { describe, expect, it, vi } from 'vitest' + +import { OAuthError, parseRedirectApproval, validateOAuthState } from './workers-oauth-utils' + +describe('OAuthError', () => { + it('creates an error with code, description, and statusCode', () => { + const err = new OAuthError('invalid_request', 'Missing parameter', 400) + expect(err.code).toBe('invalid_request') + expect(err.description).toBe('Missing parameter') + expect(err.statusCode).toBe(400) + expect(err.name).toBe('OAuthError') + expect(err).toBeInstanceOf(Error) + }) + + it('generates a proper JSON response', () => { + const err = new OAuthError('access_denied', 'CSRF check failed', 403) + const response = err.toResponse() + expect(response.status).toBe(403) + expect(response.headers.get('Content-Type')).toBe('application/json') + }) + + it('includes error and error_description in response body', async () => { + const err = new OAuthError('invalid_request', 'Bad state', 400) + const response = err.toResponse() + const body = await response.json() + expect(body).toEqual({ + error: 'invalid_request', + error_description: 'Bad state', + }) + }) +}) + +describe('parseRedirectApproval', () => { + it('throws OAuthError 405 for non-POST requests', async () => { + const request = new Request('https://example.com/oauth/authorize', { + method: 'GET', + }) + + try { + await parseRedirectApproval(request, 'test-secret') + expect.unreachable() + } catch (e) { + expect(e).toBeInstanceOf(OAuthError) + const err = e as OAuthError + expect(err.statusCode).toBe(405) + expect(err.code).toBe('invalid_request') + } + }) + + it('throws OAuthError 400 for missing form token', async () => { + const formData = new FormData() + formData.set('state', btoa(JSON.stringify({ oauthReqInfo: { clientId: 'test' } }))) + // no csrf_token + + const request = new Request('https://example.com/oauth/authorize', { + method: 'POST', + body: formData, + }) + + try { + await parseRedirectApproval(request, 'test-secret') + expect.unreachable() + } catch (e) { + expect(e).toBeInstanceOf(OAuthError) + const err = e as OAuthError + expect(err.statusCode).toBe(400) + expect(err.code).toBe('invalid_request') + expect(err.description).toContain('Missing required form token') + } + }) + + it('throws OAuthError 403 for form token mismatch', async () => { + const formData = new FormData() + formData.set('csrf_token', 'form-token') + formData.set('state', btoa(JSON.stringify({ oauthReqInfo: { clientId: 'test' } }))) + + const request = new Request('https://example.com/oauth/authorize', { + method: 'POST', + body: formData, + headers: { + Cookie: '__Host-CSRF_TOKEN=different-token', + }, + }) + + try { + await parseRedirectApproval(request, 'test-secret') + expect.unreachable() + } catch (e) { + expect(e).toBeInstanceOf(OAuthError) + const err = e as OAuthError + expect(err.statusCode).toBe(403) + expect(err.code).toBe('access_denied') + expect(err.description).toBe('Request validation failed') + } + }) + + it('throws OAuthError 400 for missing state', async () => { + const csrfToken = 'matching-token' + const formData = new FormData() + formData.set('csrf_token', csrfToken) + // no state + + const request = new Request('https://example.com/oauth/authorize', { + method: 'POST', + body: formData, + headers: { + Cookie: `__Host-CSRF_TOKEN=${csrfToken}`, + }, + }) + + try { + await parseRedirectApproval(request, 'test-secret') + expect.unreachable() + } catch (e) { + expect(e).toBeInstanceOf(OAuthError) + const err = e as OAuthError + expect(err.statusCode).toBe(400) + expect(err.code).toBe('invalid_request') + expect(err.description).toContain('Missing state') + } + }) + + it('throws OAuthError 400 for malformed state encoding', async () => { + const csrfToken = 'matching-token' + const formData = new FormData() + formData.set('csrf_token', csrfToken) + formData.set('state', '!!!not-valid-base64!!!') + + const request = new Request('https://example.com/oauth/authorize', { + method: 'POST', + body: formData, + headers: { + Cookie: `__Host-CSRF_TOKEN=${csrfToken}`, + }, + }) + + try { + await parseRedirectApproval(request, 'test-secret') + expect.unreachable() + } catch (e) { + expect(e).toBeInstanceOf(OAuthError) + const err = e as OAuthError + expect(err.statusCode).toBe(400) + expect(err.code).toBe('invalid_request') + expect(err.description).toContain('Invalid state encoding') + } + }) + + it('throws OAuthError 400 for invalid state data', async () => { + const csrfToken = 'matching-token' + const formData = new FormData() + formData.set('csrf_token', csrfToken) + formData.set('state', btoa(JSON.stringify({ noOauthReqInfo: true }))) + + const request = new Request('https://example.com/oauth/authorize', { + method: 'POST', + body: formData, + headers: { + Cookie: `__Host-CSRF_TOKEN=${csrfToken}`, + }, + }) + + try { + await parseRedirectApproval(request, 'test-secret') + expect.unreachable() + } catch (e) { + expect(e).toBeInstanceOf(OAuthError) + const err = e as OAuthError + expect(err.statusCode).toBe(400) + expect(err.code).toBe('invalid_request') + expect(err.description).toContain('Invalid state data') + } + }) +}) + +describe('validateOAuthState', () => { + function createMockKV(data: Record = {}) { + return { + get: vi.fn(async (key: string) => data[key] ?? null), + put: vi.fn(async () => {}), + delete: vi.fn(async () => {}), + } as unknown as KVNamespace + } + + it('throws OAuthError 400 for missing state parameter', async () => { + const request = new Request('https://example.com/callback') + const kv = createMockKV() + + try { + await validateOAuthState(request, kv) + expect.unreachable() + } catch (e) { + expect(e).toBeInstanceOf(OAuthError) + const err = e as OAuthError + expect(err.statusCode).toBe(400) + expect(err.code).toBe('invalid_request') + expect(err.description).toContain('Missing state parameter') + } + }) + + it('throws OAuthError 400 for un-decodable state', async () => { + const request = new Request('https://example.com/callback?state=not-base64-json!') + const kv = createMockKV() + + try { + await validateOAuthState(request, kv) + expect.unreachable() + } catch (e) { + expect(e).toBeInstanceOf(OAuthError) + const err = e as OAuthError + expect(err.statusCode).toBe(400) + expect(err.code).toBe('invalid_request') + expect(err.description).toContain('Failed to decode state parameter') + } + }) + + it('throws OAuthError 400 for state without token', async () => { + // Valid base64 JSON but missing the 'state' field + const state = btoa(JSON.stringify({ other: 'data' })) + const request = new Request(`https://example.com/callback?state=${state}`) + const kv = createMockKV() + + try { + await validateOAuthState(request, kv) + expect.unreachable() + } catch (e) { + expect(e).toBeInstanceOf(OAuthError) + const err = e as OAuthError + expect(err.statusCode).toBe(400) + expect(err.code).toBe('invalid_request') + expect(err.description).toContain('State token not found') + } + }) + + it('throws OAuthError 400 for expired/missing state in KV', async () => { + const stateToken = 'test-state-token' + const state = btoa(JSON.stringify({ state: stateToken })) + const request = new Request(`https://example.com/callback?state=${state}`) + const kv = createMockKV() // no data in KV + + try { + await validateOAuthState(request, kv) + expect.unreachable() + } catch (e) { + expect(e).toBeInstanceOf(OAuthError) + const err = e as OAuthError + expect(err.statusCode).toBe(400) + expect(err.code).toBe('invalid_request') + expect(err.description).toContain('Invalid or expired state') + } + }) + + it('throws OAuthError 400 for expired authorization session', async () => { + const stateToken = 'test-state-token' + const state = btoa(JSON.stringify({ state: stateToken })) + const storedData = JSON.stringify({ + oauthReqInfo: { + clientId: 'test-client', + scope: ['read'], + state: 'test', + responseType: 'code', + redirectUri: 'https://example.com', + }, + codeVerifier: 'test-verifier', + }) + const kv = createMockKV({ [`oauth:state:${stateToken}`]: storedData }) + + const request = new Request(`https://example.com/callback?state=${state}`) + // no Cookie header + + try { + await validateOAuthState(request, kv) + expect.unreachable() + } catch (e) { + expect(e).toBeInstanceOf(OAuthError) + const err = e as OAuthError + expect(err.statusCode).toBe(400) + expect(err.code).toBe('invalid_request') + expect(err.description).toContain('session expired') + } + }) + + it('throws OAuthError 403 for state hash mismatch', async () => { + const stateToken = 'test-state-token' + const state = btoa(JSON.stringify({ state: stateToken })) + const storedData = JSON.stringify({ + oauthReqInfo: { + clientId: 'test-client', + scope: ['read'], + state: 'test', + responseType: 'code', + redirectUri: 'https://example.com', + }, + codeVerifier: 'test-verifier', + }) + const kv = createMockKV({ [`oauth:state:${stateToken}`]: storedData }) + + const request = new Request(`https://example.com/callback?state=${state}`, { + headers: { + Cookie: '__Host-CONSENTED_STATE=wrong-hash-value', + }, + }) + + try { + await validateOAuthState(request, kv) + expect.unreachable() + } catch (e) { + expect(e).toBeInstanceOf(OAuthError) + const err = e as OAuthError + expect(err.statusCode).toBe(403) + expect(err.code).toBe('access_denied') + expect(err.description).toBe('Session validation failed') + } + }) + + it('throws OAuthError 400 for invalid stored state format', async () => { + const stateToken = 'test-state-token' + const state = btoa(JSON.stringify({ state: stateToken })) + + // Compute hash of stateToken to match cookie + const encoder = new TextEncoder() + const data = encoder.encode(stateToken) + const hashBuffer = await crypto.subtle.digest('SHA-256', data) + const hashArray = Array.from(new Uint8Array(hashBuffer)) + const hashHex = hashArray.map((b) => b.toString(16).padStart(2, '0')).join('') + + // Store invalid data (missing codeVerifier) + const storedData = JSON.stringify({ + oauthReqInfo: { clientId: 'test-client' }, + // missing codeVerifier + }) + const kv = createMockKV({ [`oauth:state:${stateToken}`]: storedData }) + + const request = new Request(`https://example.com/callback?state=${state}`, { + headers: { + Cookie: `__Host-CONSENTED_STATE=${hashHex}`, + }, + }) + + try { + await validateOAuthState(request, kv) + expect.unreachable() + } catch (e) { + expect(e).toBeInstanceOf(OAuthError) + const err = e as OAuthError + expect(err.statusCode).toBe(400) + expect(err.code).toBe('invalid_request') + expect(err.description).toBe('Invalid authorization state') + } + }) +}) diff --git a/packages/mcp-common/src/workers-oauth-utils.ts b/packages/mcp-common/src/workers-oauth-utils.ts index c0752f69..fbe481fc 100644 --- a/packages/mcp-common/src/workers-oauth-utils.ts +++ b/packages/mcp-common/src/workers-oauth-utils.ts @@ -551,14 +551,14 @@ export async function parseRedirectApproval( cookieSecret: string ): Promise { if (request.method !== 'POST') { - throw new Error('Invalid request method. Expected POST.') + throw new OAuthError('invalid_request', 'Invalid request method. Expected POST.', 405) } const formData = await request.formData() const tokenFromForm = formData.get('csrf_token') if (!tokenFromForm || typeof tokenFromForm !== 'string') { - throw new Error('Missing CSRF token in form data') + throw new OAuthError('invalid_request', 'Missing required form token', 400) } const cookieHeader = request.headers.get('Cookie') || '' @@ -567,17 +567,22 @@ export async function parseRedirectApproval( const tokenFromCookie = csrfCookie ? csrfCookie.substring('__Host-CSRF_TOKEN='.length) : null if (!tokenFromCookie || tokenFromForm !== tokenFromCookie) { - throw new Error('CSRF token mismatch') + throw new OAuthError('access_denied', 'Request validation failed', 403) } const encodedState = formData.get('state') if (!encodedState || typeof encodedState !== 'string') { - throw new Error('Missing state in form data') + throw new OAuthError('invalid_request', 'Missing state in form data', 400) } - const state = JSON.parse(atob(encodedState)) + let state: { oauthReqInfo?: AuthRequest } + try { + state = JSON.parse(atob(encodedState)) + } catch { + throw new OAuthError('invalid_request', 'Invalid state encoding', 400) + } if (!state.oauthReqInfo || !state.oauthReqInfo.clientId) { - throw new Error('Invalid state data') + throw new OAuthError('invalid_request', 'Invalid state data', 400) } const existingApprovedClients = @@ -694,7 +699,7 @@ export async function validateOAuthState( const stateFromQuery = url.searchParams.get('state') if (!stateFromQuery) { - throw new Error('Missing state parameter') + throw new OAuthError('invalid_request', 'Missing state parameter', 400) } // Decode the state parameter to extract the embedded stateToken @@ -703,15 +708,16 @@ export async function validateOAuthState( const decodedState = JSON.parse(atob(stateFromQuery)) stateToken = decodedState.state if (!stateToken) { - throw new Error('State token not found in decoded state') + throw new OAuthError('invalid_request', 'State token not found in decoded state', 400) } } catch (e) { - throw new Error('Failed to decode state parameter') + if (e instanceof OAuthError) throw e + throw new OAuthError('invalid_request', 'Failed to decode state parameter', 400) } const storedDataJson = await kv.get(`oauth:state:${stateToken}`) if (!storedDataJson) { - throw new Error('Invalid or expired state') + throw new OAuthError('invalid_request', 'Invalid or expired state', 400) } const cookieHeader = request.headers.get('Cookie') || '' @@ -722,7 +728,11 @@ export async function validateOAuthState( : null if (!consentedStateHash) { - throw new Error('Missing session binding cookie - authorization flow must be restarted') + throw new OAuthError( + 'invalid_request', + 'Authorization session expired, please restart the flow', + 400 + ) } const encoder = new TextEncoder() @@ -732,7 +742,7 @@ export async function validateOAuthState( const stateHash = hashArray.map((b) => b.toString(16).padStart(2, '0')).join('') if (stateHash !== consentedStateHash) { - throw new Error('State token does not match session - possible CSRF attack detected') + throw new OAuthError('access_denied', 'Session validation failed', 403) } // Parse and validate stored OAuth state data @@ -751,7 +761,7 @@ export async function validateOAuthState( const parseResult = StoredOAuthStateSchema.safeParse(JSON.parse(storedDataJson)) if (!parseResult.success) { - throw new Error('Invalid OAuth state data format - PKCE security violation') + throw new OAuthError('invalid_request', 'Invalid authorization state', 400) } await kv.delete(`oauth:state:${stateToken}`) diff --git a/packages/mcp-common/vitest.config.ts b/packages/mcp-common/vitest.config.ts index 601865c4..95179ad7 100644 --- a/packages/mcp-common/vitest.config.ts +++ b/packages/mcp-common/vitest.config.ts @@ -1,3 +1,4 @@ +import path from 'node:path' import { defineWorkersProject } from '@cloudflare/vitest-pool-workers/config' export interface TestEnv { @@ -11,14 +12,22 @@ export default defineWorkersProject({ workers: { singleWorker: true, miniflare: { - compatibilityDate: '2025-03-10', + compatibilityDate: '2026-03-09', compatibilityFlags: ['nodejs_compat'], bindings: { CLOUDFLARE_MOCK_ACCOUNT_ID: 'mock-account-id', CLOUDFLARE_MOCK_API_TOKEN: 'mock-api-token', - } satisfies Partial, + DEV_DISABLE_OAUTH: false, + }, }, }, }, }, + resolve: { + alias: { + // The real cloudflare SDK imports ReadStream from node:fs which is unavailable in workerd. + // Alias to a lightweight mock that provides Cloudflare and APIError classes. + cloudflare: path.resolve(__dirname, 'src/__mocks__/cloudflare.ts'), + }, + }, })