/** Moderation categories reported by Mistral's moderation endpoints. */
type GuardrailFunction =
  | 'sexual'
  | 'hate_and_discrimination'
  | 'violence_and_threats'
  | 'dangerous_and_criminal_content'
  | 'selfharm'
  | 'health'
  | 'financial'
  | 'law'
  | 'pii';

/** Subset of the moderation response body that this plugin relies on. */
interface MistralResponse {
  id: string;
  model: string;
  results: Array<{
    categories: Record<GuardrailFunction, boolean>;
    category_scores: Record<GuardrailFunction, number>;
  }>;
}

const DEFAULT_MODERATION_MODEL = 'mistral-moderation-latest';
const TEXT_MODERATION_URL = 'https://api.mistral.ai/v1/moderations';
const CHAT_MODERATION_URL = 'https://api.mistral.ai/v1/chat/moderations';

/**
 * Guardrail that sends the request/response content to Mistral's moderation
 * API and fails the verdict when any of the configured categories is flagged.
 *
 * @param context - Hook context; `request`/`response` and `requestType` are read.
 * @param parameters - `credentials.apiKey` (required), `categories` (the
 *   categories to block) and an optional `model` override.
 * @param eventType - Hook being executed (`beforeRequestHook` or `afterRequestHook`).
 * @param _options - Unused.
 * @returns `{ error, verdict, data }`:
 *   - missing API key or empty content: `verdict: false` with an error message;
 *   - flagged content: `verdict: false` and `data.flagged_categories`;
 *   - moderation call failure: the error is reported but `verdict` stays `true`.
 */
export const mistralGuardrailHandler: PluginHandler = async (
  context: PluginContext,
  parameters: PluginParameters,
  eventType: HookEventType,
  _options
) => {
  let error: unknown = null;
  let verdict = true;
  let data: { flagged_categories: string[] } | null = null;

  // Credentials may be absent altogether; treat that like a missing key
  // instead of crashing on a property access of `undefined`.
  const creds = parameters.credentials as Record<string, string> | undefined;
  const apiKey = creds?.apiKey;
  if (!apiKey) {
    return {
      error: 'Mistral API key not provided.',
      verdict: false,
      data: null,
    };
  }

  // Model can be passed dynamically.
  const model: string = parameters.model || DEFAULT_MODERATION_MODEL;

  // Without configured categories nothing can be flagged; default to an empty
  // list rather than failing later on `undefined.includes`.
  const checks: GuardrailFunction[] = Array.isArray(parameters.categories)
    ? parameters.categories
    : [];

  const text = getText(context, eventType);
  const messages =
    eventType === 'beforeRequestHook'
      ? context.request?.json?.messages
      : context.response?.json?.messages;
  const hasMessages = Array.isArray(messages);

  // Should contain text or a non-empty messages array.
  if ((!text && !hasMessages) || (hasMessages && messages.length === 0)) {
    return {
      error: 'Mistral: Invalid Request body',
      verdict: false,
      data: null,
    };
  }

  // Use the conversation endpoint only for chat completions on the before
  // hook, and only when there really is a messages array to send; otherwise
  // fall back to plain-text moderation.
  const shouldUseConversation =
    eventType === 'beforeRequestHook' &&
    context.requestType === 'chatComplete' &&
    hasMessages;
  const url = shouldUseConversation ? CHAT_MODERATION_URL : TEXT_MODERATION_URL;
  const input = shouldUseConversation ? [messages] : [text];

  try {
    const response: MistralResponse = await post(
      url,
      { model, input },
      {
        headers: {
          Authorization: `Bearer ${apiKey}`,
          'Content-Type': 'application/json',
        },
      }
    );

    const categories: Partial<Record<GuardrailFunction, boolean>> =
      response.results[0]?.categories ?? {};
    // Iterate in the API's key order so the reported list is stable.
    const categoriesFlagged = (
      Object.keys(categories) as GuardrailFunction[]
    ).filter(
      (category) => checks.includes(category) && Boolean(categories[category])
    );

    if (categoriesFlagged.length > 0) {
      verdict = false;
      data = { flagged_categories: categoriesFlagged };
    }
  } catch (err) {
    // Fail open: a moderation outage must not block traffic, but the error is
    // still surfaced to the caller.
    error = err;
    verdict = true;
  }

  return { error, verdict, data };
};
/**
 * Builds handler parameters for a test run.
 *
 * Credentials are shallow-copied so a test that overrides `apiKey` never
 * mutates the shared `.creds.json` import.
 */
function getParameters(overrides: { categories?: string[] } = {}) {
  return {
    credentials: { ...testCreds },
    ...overrides,
  };
}

/** Runs the guardrail on `text` as a `beforeRequestHook`. */
function moderate(text: string, parameters: ReturnType<typeof getParameters>) {
  // Only `request.text` is populated; the cast stands in for a full context.
  const context = { request: { text } };
  return mistralGuardrailHandler(
    context as unknown as PluginContext,
    parameters,
    'beforeRequestHook',
    { env: {} }
  );
}

// NOTE: these are integration tests; they call the live Mistral API using the
// key stored in `.creds.json`.
describe('mistral guardrail handler', () => {
  it('should fail open and report an error if the apiKey is invalid', async () => {
    const parameters = getParameters();
    parameters.credentials.apiKey = 'invalid-api-key';

    const result = await moderate(
      'this is a test string for moderations',
      parameters
    );

    expect(result).toBeDefined();
    expect(result.verdict).toBe(true);
    // `toBeDefined()` would also accept `null`; require a real error here.
    expect(result.error).not.toBeNull();
    expect(result.error).toBeDefined();
    expect(result.data).toBeNull();
  });

  it('should succeed and return the flagged categories', async () => {
    const result = await moderate(
      'my name is John Doe and my email is john.doe@example.com',
      getParameters({ categories: ['pii'] })
    );

    expect(result).toBeDefined();
    expect(result.verdict).toBe(false);
    expect(result.error).toBeNull();
    expect(result.data).toMatchObject({ flagged_categories: ['pii'] });
  });

  it('should include the multiple flagged categories in the response', async () => {
    const result = await moderate(
      'my name is John Doe and my email is john.doe@example.com. I am a financial advisor and I suggest you to invest in the stock market in company A.',
      getParameters({ categories: ['pii', 'financial'] })
    );

    expect(result).toBeDefined();
    expect(result.verdict).toBe(false);
    expect(result.error).toBeNull();
    // Compare order-independently; the order comes from the API response.
    const flagged = [...(result.data.flagged_categories as string[])].sort();
    expect(flagged).toEqual(['financial', 'pii']);
  });

  it('should pass when none of the selected categories is flagged', async () => {
    const result = await moderate(
      'this is safe string without any flagged categories',
      getParameters({ categories: ['pii', 'financial'] })
    );

    expect(result).toBeDefined();
    expect(result.verdict).toBe(true);
    expect(result.error).toBeNull();
    expect(result.data).toBeNull();
  });
});