From 39d33337199acd6e763d90284796371cc022629e Mon Sep 17 00:00:00 2001 From: visargD Date: Fri, 11 Oct 2024 18:08:18 +0530 Subject: [PATCH 1/3] chore: add plugin options interface --- src/middlewares/hooks/types.ts | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/src/middlewares/hooks/types.ts b/src/middlewares/hooks/types.ts index e6ff1a100..a2cbd9f74 100644 --- a/src/middlewares/hooks/types.ts +++ b/src/middlewares/hooks/types.ts @@ -1,6 +1,7 @@ export interface Check { id: string; parameters: object; + is_enabled?: boolean; } export interface HookOnFailObject { @@ -86,3 +87,7 @@ export interface GuardrailResult { export type HookResult = GuardrailResult; export type EventType = 'beforeRequestHook' | 'afterRequestHook'; + +export interface HandlerOptions { + env: Record; +} From 0123c122025512498cf1714a2323aea05b1b4c36 Mon Sep 17 00:00:00 2001 From: visargD Date: Fri, 11 Oct 2024 18:11:13 +0530 Subject: [PATCH 2/3] chore: pass env to plugin handlers --- plugins/portkey/gibberish.ts | 4 +++- plugins/portkey/language.ts | 4 +++- plugins/portkey/moderateContent.ts | 4 +++- plugins/portkey/pii.ts | 14 ++++++++++---- plugins/types.ts | 5 ++++- src/handlers/handlerUtils.ts | 8 +++++--- src/handlers/responseHandlers.ts | 9 ++++++--- src/middlewares/hooks/index.ts | 30 ++++++++++++++++++++++-------- 8 files changed, 56 insertions(+), 22 deletions(-) diff --git a/plugins/portkey/gibberish.ts b/plugins/portkey/gibberish.ts index 7367cefae..fd39d7c33 100644 --- a/plugins/portkey/gibberish.ts +++ b/plugins/portkey/gibberish.ts @@ -10,7 +10,8 @@ import { PORTKEY_ENDPOINTS, fetchPortkey } from './globals'; export const handler: PluginHandler = async ( context: PluginContext, parameters: PluginParameters, - eventType: HookEventType + eventType: HookEventType, + options ) => { let error = null; let verdict = false; @@ -22,6 +23,7 @@ export const handler: PluginHandler = async ( // Check if the text is gibberish const response: any = await fetchPortkey( + options.env, PORTKEY_ENDPOINTS.GIBBERISH, parameters.credentials, { input: text } diff --git a/plugins/portkey/language.ts b/plugins/portkey/language.ts index 4139088b0..9c81c4a7d 100644 --- a/plugins/portkey/language.ts +++ b/plugins/portkey/language.ts @@ -10,7 +10,8 @@ import { PORTKEY_ENDPOINTS, fetchPortkey } from './globals'; export const handler: PluginHandler = async ( context: PluginContext, parameters: PluginParameters, - eventType: HookEventType + eventType: HookEventType, + options ) => { let error = null; let verdict = false; @@ -23,6 +24,7 @@ export const handler: PluginHandler = async ( // Find the language of the text const result: any = await fetchPortkey( + options.env, PORTKEY_ENDPOINTS.LANGUAGE, parameters.credentials, { input: text } diff --git a/plugins/portkey/moderateContent.ts b/plugins/portkey/moderateContent.ts index f02c45f23..ca9062adf 100644 --- a/plugins/portkey/moderateContent.ts +++ b/plugins/portkey/moderateContent.ts @@ -10,7 +10,8 @@ import { PORTKEY_ENDPOINTS, fetchPortkey } from './globals'; export const handler: PluginHandler = async ( context: PluginContext, parameters: PluginParameters, - eventType: HookEventType + eventType: HookEventType, + options ) => { let error = null; let verdict = false; @@ -23,6 +24,7 @@ export const handler: PluginHandler = async ( // Get data from the relevant tool const result: any = await fetchPortkey( + options.env, PORTKEY_ENDPOINTS.MODERATIONS, parameters.credentials, { input: text } diff --git a/plugins/portkey/pii.ts b/plugins/portkey/pii.ts index 9e809aacc..70e8685b4 100644 --- a/plugins/portkey/pii.ts +++ b/plugins/portkey/pii.ts @@ -7,8 +7,12 @@ import { import { getText } from '../utils'; import { PORTKEY_ENDPOINTS, fetchPortkey } from './globals'; -async function detectPII(text: string, credentials: any) { - const result = await fetchPortkey(PORTKEY_ENDPOINTS.PII, credentials, { +async function detectPII( + text: string, + credentials: any, + env: Record +) { + const result = await fetchPortkey(env, PORTKEY_ENDPOINTS.PII, credentials, { input: text, }); @@ -36,7 +40,8 @@ async function detectPII(text: string, credentials: any) { export const handler: PluginHandler = async ( context: PluginContext, parameters: PluginParameters, - eventType: HookEventType + eventType: HookEventType, + options ) => { let error = null; let verdict = false; @@ -49,7 +54,8 @@ export const handler: PluginHandler = async ( let { detectedPIICategories, PIIData } = await detectPII( text, - parameters.credentials + parameters.credentials, + options.env ); // Filter the detected categories based on the categories to check diff --git a/plugins/types.ts b/plugins/types.ts index 881eb1839..c4f37a7f5 100644 --- a/plugins/types.ts +++ b/plugins/types.ts @@ -19,5 +19,8 @@ export type HookEventType = 'beforeRequestHook' | 'afterRequestHook'; export type PluginHandler = ( context: PluginContext, parameters: PluginParameters, - eventType: HookEventType + eventType: HookEventType, + options: { + env: Record; + } ) => Promise; diff --git a/src/handlers/handlerUtils.ts b/src/handlers/handlerUtils.ts index 58ad67256..02f36db1a 100644 --- a/src/handlers/handlerUtils.ts +++ b/src/handlers/handlerUtils.ts @@ -1328,9 +1328,11 @@ export async function beforeRequestHookHandler( ): Promise { try { const hooksManager = c.get('hooksManager'); - const hooksResult = await hooksManager.executeHooks(hookSpanId, [ - 'syncBeforeRequestHook', - ]); + const hooksResult = await hooksManager.executeHooks( + hookSpanId, + ['syncBeforeRequestHook'], + { env: env(c) } + ); if (hooksResult.shouldDeny) { return new Response( diff --git a/src/handlers/responseHandlers.ts b/src/handlers/responseHandlers.ts index b5875c15f..15db82bca 100644 --- a/src/handlers/responseHandlers.ts +++ b/src/handlers/responseHandlers.ts @@ -15,6 +15,7 @@ import { handleTextResponse, } from './streamHandler'; import { HookSpan } from '../middlewares/hooks'; +import { env } from 'hono/adapter'; /** * Handles various types of responses based on the specified parameters @@ -162,9 +163,11 @@ export async function afterRequestHookHandler( hooksManager.getSpan(hookSpanId).resetHookResult('afterRequestHook'); } - let { shouldDeny, results } = await hooksManager.executeHooks(hookSpanId, [ - 'syncAfterRequestHook', - ]); + let { shouldDeny, results } = await hooksManager.executeHooks( + hookSpanId, + ['syncAfterRequestHook'], + { env: env(c) } + ); if (!responseJSON) { return response; diff --git a/src/middlewares/hooks/index.ts b/src/middlewares/hooks/index.ts index 2b4afe66c..bfdca842a 100644 --- a/src/middlewares/hooks/index.ts +++ b/src/middlewares/hooks/index.ts @@ -8,6 +8,7 @@ import { HookOnFailObject, HookOnSuccessObject, HookResult, + HandlerOptions, } from './types'; import { plugins } from '../../../plugins'; import { Context } from 'hono'; @@ -213,7 +214,8 @@ export class HooksManager { public async executeHooks( spanId: string, - eventTypePresets: string[] + eventTypePresets: string[], + options: HandlerOptions ): Promise<{ results: HookResult[]; shouldDeny: boolean }> { const span = this.getSpan(spanId); @@ -225,7 +227,9 @@ export class HooksManager { try { const results = await Promise.all( - hooksToExecute.map((hook) => this.executeEachHook(spanId, hook)) + hooksToExecute.map((hook) => + this.executeEachHook(spanId, hook, options) + ) ); const shouldDeny = results.some( (result, index) => @@ -247,7 +251,8 @@ export class HooksManager { private async executeFunction( context: HookSpanContext, check: Check, - eventType: EventType + eventType: EventType, + options: HandlerOptions ): Promise { const [source, fn] = check.id.split('.'); const createdAt = new Date(); @@ -255,7 +260,8 @@ export class HooksManager { const result = await this.plugins[source][fn]( context, check.parameters, - eventType + eventType, + options ); return { ...result, @@ -284,7 +290,8 @@ export class HooksManager { private async executeEachHook( spanId: string, - hook: HookObject + hook: HookObject, + options: HandlerOptions ): Promise { const span = this.getSpan(spanId); let hookResult: HookResult = { id: hook.id } as HookResult; @@ -296,9 +303,16 @@ export class HooksManager { if (hook.type === 'guardrail' && hook.checks) { const checkResults = await Promise.all( - hook.checks.map((check: Check) => - this.executeFunction(span.getContext(), check, hook.eventType) - ) + hook.checks + .filter((check: Check) => check.is_enabled !== false) + .map((check: Check) => + this.executeFunction( + span.getContext(), + check, + hook.eventType, + options + ) + ) ); hookResult = { From b1cf8725d01ba75d43d4e4310f3a2d1aded5ff58 Mon Sep 17 00:00:00 2001 From: visargD Date: Fri, 11 Oct 2024 18:12:44 +0530 Subject: [PATCH 3/3] feat: add optional service binding logic for portkey plugins --- plugins/portkey/globals.ts | 19 +++++++-- plugins/utils.ts | 83 ++++++++++++++++++++++++++++++++++++++ 2 files changed, 99 insertions(+), 3 deletions(-) diff --git a/plugins/portkey/globals.ts b/plugins/portkey/globals.ts index c78a9238f..6a7d4af22 100644 --- a/plugins/portkey/globals.ts +++ b/plugins/portkey/globals.ts @@ -1,4 +1,7 @@ -import { post } from '../utils'; +import { getRuntimeKey } from 'hono/adapter'; +import { post, postWithCloudflareServiceBinding } from '../utils'; + +export const BASE_URL = 'https://api.portkey.ai/v1/execute-guardrails'; export const PORTKEY_ENDPOINTS = { MODERATIONS: '/moderations', @@ -8,15 +11,25 @@ export const PORTKEY_ENDPOINTS = { }; export const fetchPortkey = async ( + env: Record, endpoint: string, credentials: any, data: any ) => { const options = { headers: { - Authorization: `Bearer ${credentials.apiKey}`, + 'x-portkey-api-key': credentials.apiKey, }, }; - return post(`${credentials.baseURL}${endpoint}`, data, options); + if (getRuntimeKey() === 'workerd' && env.portkeyGuardrails) { + return postWithCloudflareServiceBinding( + `${BASE_URL}${endpoint}`, + data, + env.portkeyGuardrails, + options + ); + } + + return post(`${BASE_URL}${endpoint}`, data, options); }; diff --git a/plugins/utils.ts b/plugins/utils.ts index 986cb9728..8ca826e7d 100644 --- a/plugins/utils.ts +++ b/plugins/utils.ts @@ -129,3 +129,86 @@ export async function post( throw error; } } + +/** + * Sends a POST request to the specified URL with the given data and timeout. + * @param url - The URL to send the POST request to. + * @param data - The data to be sent in the request body. + * @param options - Additional options for the fetch call. + * @param timeout - Timeout in milliseconds (default: 5 seconds). + * @returns A promise that resolves to the JSON response. + * @throws {HttpError} Throws an HttpError with detailed information if the request fails. + * @throws {Error} Throws a generic Error for network issues or timeouts. + */ +export async function postWithCloudflareServiceBinding( + url: string, + data: any, + serviceBinding: any, + options: PostOptions = {}, + timeout: number = 5000 +): Promise { + const defaultOptions: PostOptions = { + method: 'POST', + headers: { + 'Content-Type': 'application/json', + }, + body: JSON.stringify(data), + }; + + const mergedOptions: PostOptions = { ...defaultOptions, ...options }; + + if (mergedOptions.headers) { + mergedOptions.headers = { + ...defaultOptions.headers, + ...mergedOptions.headers, + }; + } + + try { + const controller = new AbortController(); + const id = setTimeout(() => controller.abort(), timeout); + + const response: Response = await serviceBinding.fetch(url, { + ...mergedOptions, + signal: controller.signal, + }); + + clearTimeout(id); + + if (!response.ok) { + let errorBody: string; + try { + errorBody = await response.text(); + } catch (e) { + errorBody = 'Unable to retrieve response body'; + } + + const errorResponse: ErrorResponse = { + status: response.status, + statusText: response.statusText, + body: errorBody, + }; + + throw new HttpError( + `HTTP error! status: ${response.status}`, + errorResponse + ); + } + + return (await response.json()) as T; + } catch (error: any) { + if (error instanceof HttpError) { + throw error; + } + if (error.name === 'AbortError') { + throw new TimeoutError( + `Request timed out after ${timeout}ms`, + url, + timeout, + mergedOptions.method || 'POST' + ); + } + // console.error('Error in post request:', error); + throw error; + } +}