Skip to content

Commit

Permalink
Merge pull request #677 from Portkey-AI/feat/portley-llm-guardrails
Browse files Browse the repository at this point in the history
Feat: portley llm guardrails update
  • Loading branch information
VisargD authored Oct 11, 2024
2 parents 8404e3a + b1cf872 commit b64b8a2
Show file tree
Hide file tree
Showing 11 changed files with 160 additions and 25 deletions.
4 changes: 3 additions & 1 deletion plugins/portkey/gibberish.ts
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,8 @@ import { PORTKEY_ENDPOINTS, fetchPortkey } from './globals';
export const handler: PluginHandler = async (
context: PluginContext,
parameters: PluginParameters,
eventType: HookEventType
eventType: HookEventType,
options
) => {
let error = null;
let verdict = false;
Expand All @@ -22,6 +23,7 @@ export const handler: PluginHandler = async (

// Check if the text is gibberish
const response: any = await fetchPortkey(
options.env,
PORTKEY_ENDPOINTS.GIBBERISH,
parameters.credentials,
{ input: text }
Expand Down
19 changes: 16 additions & 3 deletions plugins/portkey/globals.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
import { post } from '../utils';
import { getRuntimeKey } from 'hono/adapter';
import { post, postWithCloudflareServiceBinding } from '../utils';

export const BASE_URL = 'https://api.portkey.ai/v1/execute-guardrails';

export const PORTKEY_ENDPOINTS = {
MODERATIONS: '/moderations',
Expand All @@ -8,15 +11,25 @@ export const PORTKEY_ENDPOINTS = {
};

export const fetchPortkey = async (
env: Record<string, any>,
endpoint: string,
credentials: any,
data: any
) => {
const options = {
headers: {
Authorization: `Bearer ${credentials.apiKey}`,
'x-portkey-api-key': credentials.apiKey,
},
};

return post(`${credentials.baseURL}${endpoint}`, data, options);
if (getRuntimeKey() === 'workerd' && env.portkeyGuardrails) {
return postWithCloudflareServiceBinding(
`${BASE_URL}${endpoint}`,
data,
env.portkeyGuardrails,
options
);
}

return post(`${BASE_URL}${endpoint}`, data, options);
};
4 changes: 3 additions & 1 deletion plugins/portkey/language.ts
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,8 @@ import { PORTKEY_ENDPOINTS, fetchPortkey } from './globals';
export const handler: PluginHandler = async (
context: PluginContext,
parameters: PluginParameters,
eventType: HookEventType
eventType: HookEventType,
options
) => {
let error = null;
let verdict = false;
Expand All @@ -23,6 +24,7 @@ export const handler: PluginHandler = async (

// Find the language of the text
const result: any = await fetchPortkey(
options.env,
PORTKEY_ENDPOINTS.LANGUAGE,
parameters.credentials,
{ input: text }
Expand Down
4 changes: 3 additions & 1 deletion plugins/portkey/moderateContent.ts
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,8 @@ import { PORTKEY_ENDPOINTS, fetchPortkey } from './globals';
export const handler: PluginHandler = async (
context: PluginContext,
parameters: PluginParameters,
eventType: HookEventType
eventType: HookEventType,
options
) => {
let error = null;
let verdict = false;
Expand All @@ -23,6 +24,7 @@ export const handler: PluginHandler = async (

// Get data from the relevant tool
const result: any = await fetchPortkey(
options.env,
PORTKEY_ENDPOINTS.MODERATIONS,
parameters.credentials,
{ input: text }
Expand Down
14 changes: 10 additions & 4 deletions plugins/portkey/pii.ts
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,12 @@ import {
import { getText } from '../utils';
import { PORTKEY_ENDPOINTS, fetchPortkey } from './globals';

async function detectPII(text: string, credentials: any) {
const result = await fetchPortkey(PORTKEY_ENDPOINTS.PII, credentials, {
async function detectPII(
text: string,
credentials: any,
env: Record<string, any>
) {
const result = await fetchPortkey(env, PORTKEY_ENDPOINTS.PII, credentials, {
input: text,
});

Expand Down Expand Up @@ -36,7 +40,8 @@ async function detectPII(text: string, credentials: any) {
export const handler: PluginHandler = async (
context: PluginContext,
parameters: PluginParameters,
eventType: HookEventType
eventType: HookEventType,
options
) => {
let error = null;
let verdict = false;
Expand All @@ -49,7 +54,8 @@ export const handler: PluginHandler = async (

let { detectedPIICategories, PIIData } = await detectPII(
text,
parameters.credentials
parameters.credentials,
options.env
);

// Filter the detected categories based on the categories to check
Expand Down
5 changes: 4 additions & 1 deletion plugins/types.ts
Original file line number Diff line number Diff line change
Expand Up @@ -19,5 +19,8 @@ export type HookEventType = 'beforeRequestHook' | 'afterRequestHook';
export type PluginHandler = (
context: PluginContext,
parameters: PluginParameters,
eventType: HookEventType
eventType: HookEventType,
options: {
env: Record<string, any>;
}
) => Promise<PluginHandlerResponse>;
83 changes: 83 additions & 0 deletions plugins/utils.ts
Original file line number Diff line number Diff line change
Expand Up @@ -129,3 +129,86 @@ export async function post<T = any>(
throw error;
}
}

/**
* Sends a POST request to the specified URL with the given data and timeout.
* @param url - The URL to send the POST request to.
* @param data - The data to be sent in the request body.
* @param options - Additional options for the fetch call.
* @param timeout - Timeout in milliseconds (default: 5 seconds).
* @returns A promise that resolves to the JSON response.
* @throws {HttpError} Throws an HttpError with detailed information if the request fails.
* @throws {Error} Throws a generic Error for network issues or timeouts.
*/
export async function postWithCloudflareServiceBinding<T = any>(
url: string,
data: any,
serviceBinding: any,
options: PostOptions = {},
timeout: number = 5000
): Promise<T> {
const defaultOptions: PostOptions = {
method: 'POST',
headers: {
'Content-Type': 'application/json',
},
body: JSON.stringify(data),
};

const mergedOptions: PostOptions = { ...defaultOptions, ...options };

if (mergedOptions.headers) {
mergedOptions.headers = {
...defaultOptions.headers,
...mergedOptions.headers,
};
}

try {
const controller = new AbortController();
const id = setTimeout(() => controller.abort(), timeout);

const response: Response = await serviceBinding.fetch(url, {
...mergedOptions,
signal: controller.signal,
});

clearTimeout(id);

if (!response.ok) {
let errorBody: string;
try {
errorBody = await response.text();
} catch (e) {
errorBody = 'Unable to retrieve response body';
}

const errorResponse: ErrorResponse = {
status: response.status,
statusText: response.statusText,
body: errorBody,
};

throw new HttpError(
`HTTP error! status: ${response.status}`,
errorResponse
);
}

return (await response.json()) as T;
} catch (error: any) {
if (error instanceof HttpError) {
throw error;
}
if (error.name === 'AbortError') {
throw new TimeoutError(
`Request timed out after ${timeout}ms`,
url,
timeout,
mergedOptions.method || 'POST'
);
}
// console.error('Error in post request:', error);
throw error;
}
}
8 changes: 5 additions & 3 deletions src/handlers/handlerUtils.ts
Original file line number Diff line number Diff line change
Expand Up @@ -1328,9 +1328,11 @@ export async function beforeRequestHookHandler(
): Promise<any> {
try {
const hooksManager = c.get('hooksManager');
const hooksResult = await hooksManager.executeHooks(hookSpanId, [
'syncBeforeRequestHook',
]);
const hooksResult = await hooksManager.executeHooks(
hookSpanId,
['syncBeforeRequestHook'],
{ env: env(c) }
);

if (hooksResult.shouldDeny) {
return new Response(
Expand Down
9 changes: 6 additions & 3 deletions src/handlers/responseHandlers.ts
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ import {
handleTextResponse,
} from './streamHandler';
import { HookSpan } from '../middlewares/hooks';
import { env } from 'hono/adapter';

/**
* Handles various types of responses based on the specified parameters
Expand Down Expand Up @@ -162,9 +163,11 @@ export async function afterRequestHookHandler(
hooksManager.getSpan(hookSpanId).resetHookResult('afterRequestHook');
}

let { shouldDeny, results } = await hooksManager.executeHooks(hookSpanId, [
'syncAfterRequestHook',
]);
let { shouldDeny, results } = await hooksManager.executeHooks(
hookSpanId,
['syncAfterRequestHook'],
{ env: env(c) }
);

if (!responseJSON) {
return response;
Expand Down
30 changes: 22 additions & 8 deletions src/middlewares/hooks/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import {
HookOnFailObject,
HookOnSuccessObject,
HookResult,
HandlerOptions,
} from './types';
import { plugins } from '../../../plugins';
import { Context } from 'hono';
Expand Down Expand Up @@ -213,7 +214,8 @@ export class HooksManager {

public async executeHooks(
spanId: string,
eventTypePresets: string[]
eventTypePresets: string[],
options: HandlerOptions
): Promise<{ results: HookResult[]; shouldDeny: boolean }> {
const span = this.getSpan(spanId);

Expand All @@ -225,7 +227,9 @@ export class HooksManager {

try {
const results = await Promise.all(
hooksToExecute.map((hook) => this.executeEachHook(spanId, hook))
hooksToExecute.map((hook) =>
this.executeEachHook(spanId, hook, options)
)
);
const shouldDeny = results.some(
(result, index) =>
Expand All @@ -247,15 +251,17 @@ export class HooksManager {
private async executeFunction(
context: HookSpanContext,
check: Check,
eventType: EventType
eventType: EventType,
options: HandlerOptions
): Promise<GuardrailCheckResult> {
const [source, fn] = check.id.split('.');
const createdAt = new Date();
try {
const result = await this.plugins[source][fn](
context,
check.parameters,
eventType
eventType,
options
);
return {
...result,
Expand Down Expand Up @@ -284,7 +290,8 @@ export class HooksManager {

private async executeEachHook(
spanId: string,
hook: HookObject
hook: HookObject,
options: HandlerOptions
): Promise<HookResult> {
const span = this.getSpan(spanId);
let hookResult: HookResult = { id: hook.id } as HookResult;
Expand All @@ -296,9 +303,16 @@ export class HooksManager {

if (hook.type === 'guardrail' && hook.checks) {
const checkResults = await Promise.all(
hook.checks.map((check: Check) =>
this.executeFunction(span.getContext(), check, hook.eventType)
)
hook.checks
.filter((check: Check) => check.is_enabled !== false)
.map((check: Check) =>
this.executeFunction(
span.getContext(),
check,
hook.eventType,
options
)
)
);

hookResult = {
Expand Down
5 changes: 5 additions & 0 deletions src/middlewares/hooks/types.ts
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
export interface Check {
id: string;
parameters: object;
is_enabled?: boolean;
}

export interface HookOnFailObject {
Expand Down Expand Up @@ -86,3 +87,7 @@ export interface GuardrailResult {
export type HookResult = GuardrailResult;

export type EventType = 'beforeRequestHook' | 'afterRequestHook';

export interface HandlerOptions {
env: Record<string, any>;
}

0 comments on commit b64b8a2

Please sign in to comment.