diff --git a/src/handlers/handlerUtils.ts b/src/handlers/handlerUtils.ts index fb031e71d..ab4cc63de 100644 --- a/src/handlers/handlerUtils.ts +++ b/src/handlers/handlerUtils.ts @@ -290,7 +290,8 @@ export async function tryPost( requestHeaders: Record, fn: endpointStrings, currentIndex: number | string, - method: string = 'POST' + method: string = 'POST', + abortSignal?: AbortSignal ): Promise { const requestContext = new RequestContext( c, @@ -301,6 +302,9 @@ export async function tryPost( method, currentIndex as number ); + if (abortSignal) { + requestContext.setAbortSignal(abortSignal); + } const hooksService = new HooksService(requestContext); const providerContext = new ProviderContext(requestContext.provider); const logsService = new LogsService(c); @@ -479,7 +483,8 @@ export async function tryTargetsRecursively( fn: endpointStrings, method: string, jsonPath: string, - inheritedConfig: Record = {} + inheritedConfig: Record = {}, + abortSignal?: AbortSignal ): Promise { const currentTarget: any = { ...targetGroup }; let currentJsonPath = jsonPath; @@ -669,7 +674,8 @@ export async function tryTargetsRecursively( fn, method, `${currentJsonPath}.targets[${originalIndex}]`, - currentInheritedConfig + currentInheritedConfig, + abortSignal ); const codes = currentTarget.strategy?.onStatusCodes; const gatewayException = @@ -712,7 +718,8 @@ export async function tryTargetsRecursively( fn, method, currentJsonPath, - currentInheritedConfig + currentInheritedConfig, + abortSignal ); break; } @@ -720,6 +727,122 @@ export async function tryTargetsRecursively( } break; + case StrategyModes.SAMPLE: { + const targets = currentTarget.targets || []; + const onStatusCodes = currentTarget.strategy?.onStatusCodes; + const cancelOthers = currentTarget.strategy?.cancelOthers; + + // v1 limitation: do not support sampling when request body is a ReadableStream + if (request instanceof ReadableStream) { + response = new Response( + JSON.stringify({ + status: 'failure', + message: + 'Strategy "sample" does not support streaming request bodies in v1', + }), + { status: 400, headers: { 'content-type': 'application/json' } } + ); + break; + } + + // Fire all requests in parallel; pick first-success + let winnerResolved = false; + let resolveWinner: (value: Response) => void = () => {}; + const winnerPromise = new Promise((resolve) => { + resolveWinner = resolve; + }); + + const controllers: AbortController[] = []; + const pendingPromises: Array< + Promise<{ resp: Response; idx: number; abort: AbortController }> + > = targets.map((t: Targets, index: number) => { + const originalIndex = (t.originalIndex as number | undefined) ?? index; + const controller = new AbortController(); + controllers.push(controller); + return tryTargetsRecursively( + c, + t, + request, + requestHeaders, + fn, + method, + `${currentJsonPath}.targets[${originalIndex}]`, + currentInheritedConfig, + controller.signal + ).then((resp) => ({ resp, idx: originalIndex, abort: controller })); + }); + + // Resolve on first-success + for (const p of pendingPromises) { + p.then(({ resp, abort }) => { + if (winnerResolved) return; + const gatewayException = + resp?.headers.get('x-portkey-gateway-exception') === 'true'; + const isSuccess = + (Array.isArray(onStatusCodes) && + !onStatusCodes.includes(resp?.status)) || + (!onStatusCodes && resp?.ok) || + gatewayException; + if (isSuccess && !winnerResolved) { + winnerResolved = true; + resolveWinner(resp); + if (cancelOthers) { + for (const ctl of controllers) { + try { + ctl.abort(); + } catch {} + } + } + } + }).catch(() => { + // Ignore individual errors; overall fallback handled below + }); + } + + // If none succeed, return the last completed response + (async () => { + const results = await Promise.allSettled(pendingPromises); + if (winnerResolved) return; + const fulfilled = results.filter( + ( + r + ): r is PromiseFulfilledResult<{ + resp: Response; + idx: number; + abort: AbortController; + }> => r.status === 'fulfilled' + ); + if (fulfilled.length) { + const { resp } = fulfilled[fulfilled.length - 1].value; + winnerResolved = true; + resolveWinner(resp); + if (cancelOthers) { + for (const ctl of controllers) { + try { + ctl.abort(); + } catch {} + } + } + } else { + // If all rejected (shouldn't generally happen because tryTargetsRecursively guards), pick a generic 500 + winnerResolved = true; + resolveWinner( + new Response( + JSON.stringify({ + status: 'failure', + message: 'All sample targets failed', + }), + { status: 500, headers: { 'content-type': 'application/json' } } + ) + ); + } + })(); + + response = await winnerPromise; + // Note: cancelOthers is a no-op for now; underlying fetch cancellation will be wired in a later update + break; + } + case StrategyModes.CONDITIONAL: { let metadata: Record; try { @@ -757,7 +880,8 @@ export async function tryTargetsRecursively( fn, method, `${currentJsonPath}.targets[${originalIndex}]`, - currentInheritedConfig + currentInheritedConfig, + abortSignal ); break; } @@ -772,7 +896,8 @@ export async function tryTargetsRecursively( fn, method, `${currentJsonPath}.targets[${originalIndex}]`, - currentInheritedConfig + currentInheritedConfig, + abortSignal ); break; @@ -785,7 +910,8 @@ export async function tryTargetsRecursively( requestHeaders, fn, currentJsonPath, - method + method, + abortSignal ); if (isHandlingCircuitBreaker) { await c.get('handleCircuitBreakerResponse')?.( @@ -1179,7 +1305,8 @@ export async function recursiveAfterRequestHookHandler( retry.onStatusCodes, requestTimeout, requestHandler, - retry.useRetryAfterHeader + retry.useRetryAfterHeader, + requestContext.abortSignal )); // Check if sync hooks are available diff --git a/src/handlers/retryHandler.ts b/src/handlers/retryHandler.ts index f6af17318..fecf3e96e 100644 --- a/src/handlers/retryHandler.ts +++ b/src/handlers/retryHandler.ts @@ -5,10 +5,18 @@ async function fetchWithTimeout( url: string, options: RequestInit, timeout: number, - requestHandler?: () => Promise + requestHandler?: () => Promise, + externalAbortSignal?: AbortSignal ) { const controller = new AbortController(); const timeoutId = setTimeout(() => controller.abort(), timeout); + if (externalAbortSignal) { + if (externalAbortSignal.aborted) { + controller.abort(); + } else { + externalAbortSignal.addEventListener('abort', () => controller.abort()); + } + } const timeoutRequestOptions = { ...options, signal: controller.signal, @@ -69,7 +77,8 @@ export const retryRequest = async ( statusCodesToRetry: number[], timeout: number | null, requestHandler?: () => Promise, - followProviderRetry?: boolean + followProviderRetry?: boolean, + externalAbortSignal?: AbortSignal ): Promise<{ response: Response; attempt: number | undefined; @@ -93,12 +102,17 @@ export const retryRequest = async ( url, options, timeout, - requestHandler + requestHandler, + externalAbortSignal ); } else if (requestHandler) { response = await requestHandler(); } else { - response = await fetch(url, options); + const noTimeoutOptions = { ...options } as RequestInit; + if (externalAbortSignal) { + noTimeoutOptions.signal = externalAbortSignal; + } + response = await fetch(url, noTimeoutOptions); } if (statusCodesToRetry.includes(response.status)) { const errorObj: any = new Error(await response.text()); diff --git a/src/handlers/services/requestContext.ts b/src/handlers/services/requestContext.ts index e389e79fc..0646fca14 100644 --- a/src/handlers/services/requestContext.ts +++ b/src/handlers/services/requestContext.ts @@ -18,6 +18,7 @@ export class RequestContext { private _transformedRequestBody: any; public readonly providerOption: Options; private _requestURL: string = ''; // Is set at the beginning of tryPost() + private _externalAbortSignal: AbortSignal | undefined; constructor( public readonly honoContext: Context, @@ -44,6 +45,14 @@ export class RequestContext { this._requestURL = requestURL; } + get abortSignal(): AbortSignal | undefined { + return this._externalAbortSignal; + } + + setAbortSignal(signal: AbortSignal) { + this._externalAbortSignal = signal; + } + get overrideParams(): Params { return this.providerOption?.overrideParams ?? {}; } diff --git a/src/middlewares/requestValidator/schema/config.ts b/src/middlewares/requestValidator/schema/config.ts index 91e95b2ea..09697b888 100644 --- a/src/middlewares/requestValidator/schema/config.ts +++ b/src/middlewares/requestValidator/schema/config.ts @@ -16,15 +16,20 @@ export const configSchema: any = z .string() .refine( (value) => - ['single', 'loadbalance', 'fallback', 'conditional'].includes( - value - ), + [ + 'single', + 'loadbalance', + 'fallback', + 'conditional', + 'sample', + ].includes(value), { message: - "Invalid 'mode' value. Must be one of: single, loadbalance, fallback, conditional", + "Invalid 'mode' value. Must be one of: single, loadbalance, fallback, conditional, sample", } ), on_status_codes: z.array(z.number()).optional(), + cancel_others: z.boolean().optional(), conditions: z .array( z.object({ diff --git a/src/types/requestBody.ts b/src/types/requestBody.ts index 452e2b83a..997fe7a4d 100644 --- a/src/types/requestBody.ts +++ b/src/types/requestBody.ts @@ -24,11 +24,13 @@ export enum StrategyModes { FALLBACK = 'fallback', SINGLE = 'single', CONDITIONAL = 'conditional', + SAMPLE = 'sample', } interface Strategy { mode: StrategyModes; onStatusCodes?: Array; + cancelOthers?: boolean; conditions?: { query: { [key: string]: any;