Skip to content

Commit

Permalink
Merge pull request #756 from b4s36t4/feat/mistral-guardrail
Browse files Browse the repository at this point in the history
feat: add mistral as a new guardrail provider
  • Loading branch information
VisargD authored Dec 26, 2024
2 parents 467c321 + d90eaf0 commit 66e9213
Show file tree
Hide file tree
Showing 5 changed files with 307 additions and 0 deletions.
4 changes: 4 additions & 0 deletions plugins/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ import { handler as patronusnoRacialBias } from './patronus/noRacialBias';
import { handler as patronusretrievalAnswerRelevance } from './patronus/retrievalAnswerRelevance';
import { handler as patronustoxicity } from './patronus/toxicity';
import { handler as patronuscustom } from './patronus/custom';
import { mistralGuardrailHandler } from './mistral';
import { handler as pangeatextGuard } from './pangea/textGuard';

export const plugins = {
Expand Down Expand Up @@ -81,6 +82,9 @@ export const plugins = {
toxicity: patronustoxicity,
custom: patronuscustom,
},
mistral: {
moderateContent: mistralGuardrailHandler,
},
pangea: {
textGuard: pangeatextGuard,
},
Expand Down
133 changes: 133 additions & 0 deletions plugins/mistral/index.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,133 @@
import {
HookEventType,
PluginContext,
PluginHandler,
PluginParameters,
} from '../types';
import { getText, post } from '../utils';

interface MistralResponse {
id: string;
model: string;
results: [
{
categories: {
sexual: boolean;
hate_and_discrimination: boolean;
violence_and_threats: boolean;
dangerous_and_criminal_content: boolean;
selfharm: boolean;
health: boolean;
financial: boolean;
law: boolean;
pii: boolean;
};
category_score: {
sexual: number;
hate_and_discrimination: number;
violence_and_threats: number;
dangerous_and_criminal_content: number;
selfharm: number;
health: number;
financial: number;
law: number;
pii: number;
};
},
];
}

type GuardrailFunction = keyof MistralResponse['results'][0]['categories'];

export const mistralGuardrailHandler: PluginHandler = async (
context: PluginContext,
parameters: PluginParameters,
eventType: HookEventType,
_options
) => {
let error = null;
let verdict = true;
let data = null;

const creds = parameters.credentials as Record<string, string>;
if (!creds.apiKey) {
return {
error: 'Mistral API key not provided.',
verdict: false,
data: null,
};
}

let model = 'mistral-moderation-latest';

if (parameters.model) {
// Model can be passed dynamically
model = parameters.model;
}

const checks = parameters.categories as GuardrailFunction[];

const text = getText(context, eventType);
const messages =
eventType === 'beforeRequestHook'
? context.request?.json?.messages
: context.response?.json?.messages;

// should contain text or should contain messages array
if (
(!text && !Array.isArray(messages)) ||
(Array.isArray(messages) && messages.length === 0)
) {
return {
error: 'Mistral: Invalid Request body',
verdict: false,
data: null,
};
}

// Use conversation guardrail if it's a chatcomplete and before hook
const shouldUseConversation =
eventType === 'beforeRequestHook' && context.requestType === 'chatComplete';
const url = shouldUseConversation
? 'https://api.mistral.ai/v1/chat/moderations'
: 'https://api.mistral.ai/v1/moderations';

try {
const request = await post<MistralResponse>(
url,
{
model: model,
...(!shouldUseConversation && { input: [text] }),
...(shouldUseConversation && { input: [messages] }),
},
{
headers: {
Authorization: `Bearer ${creds.apiKey}`,
'Content-Type': 'application/json',
},
}
);

const categories: Record<GuardrailFunction, boolean> =
request.results[0]?.categories ?? {};
const categoriesFlagged = Object.keys(categories).filter((category) => {
if (
checks.includes(category as GuardrailFunction) &&
!!categories[category as GuardrailFunction]
) {
return true;
}
return false;
});

if (categoriesFlagged.length > 0) {
verdict = false;
data = { flagged_categories: categoriesFlagged };
}
} catch (err) {
error = err;
verdict = true;
}

return { error, verdict, data };
};
66 changes: 66 additions & 0 deletions plugins/mistral/manifest.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
{
"id": "mistral",
"description": "Mistral Content Moderation classifier leverages the most relevant policy categories for effective guardrails and introduces a pragmatic approach to LLM safety by addressing model-generated harms such as unqualified advice and PII",
"credentials": {
"type": "object",
"properties": {
"apiKey": {
"type": "string",
"label": "API Key",
"description": "Find your API key in the Mistral la-plateforme",
"encrypted": true
}
},
"required": ["apiKey"]
},
"functions": [
{
"name": "Moderate Content",
"id": "moderateContent",
"type": "guardrail",
"supportedHooks": ["beforeRequestHook", "afterRequestHook"],
"description": [
{
"type": "subHeading",
"text": "Checks if the content passes the mentioned content moderation checks."
}
],
"parameters": {
"type": "object",
"properties": {
"categories": {
"type": "array",
"label": "Moderation Checks",
"description": [
{
"type": "subHeading",
"text": "Select the categories that should NOT be allowed in the content. (Checked via OpenAI moderation API)"
}
],
"items": {
"type": "string",
"enum": [
"sexual",
"hate_and_discrimination",
"violence_and_threats",
"dangerous_and_criminal_content",
"selfharm",
"health",
"financial",
"law",
"pii"
],
"default": [
"selfharm",
"pii",
"sexual",
"hate_and_discrimination"
]
}
}
},
"required": ["categories"]
}
}
]
}
102 changes: 102 additions & 0 deletions plugins/mistral/mistral.test.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,102 @@
import { PluginContext } from '../types';
import testCreds from './.creds.json';
import { mistralGuardrailHandler } from './index';

function getParameters() {
return {
credentials: testCreds,
};
}

describe('mistral guardrail handler', () => {
it('should fail if the apiKey is invalid', async () => {
const eventType = 'beforeRequestHook';
const context = {
request: { text: 'this is a test string for moderations' },
};
const parameters = JSON.parse(JSON.stringify(getParameters()));
parameters.credentials.apiKey = 'invalid-api-key';

const result = await mistralGuardrailHandler(
context as unknown as PluginContext,
parameters,
eventType,
{ env: {} }
);

expect(result).toBeDefined();
expect(result.verdict).toBe(true);
expect(result.error).toBeDefined();
expect(result.data).toBeNull();
});

it('should success and return the flagged categories', async () => {
const eventType = 'beforeRequestHook';
const context = {
request: {
text: 'my name is John Doe and my email is john.doe@example.com',
},
};
const parameters = JSON.parse(JSON.stringify(getParameters()));
parameters.categories = ['pii'];

const result = await mistralGuardrailHandler(
context as unknown as PluginContext,
parameters,
eventType,
{ env: {} }
);

expect(result).toBeDefined();
expect(result.verdict).toBe(false);
expect(result.error).toBeDefined();
expect(result.data).toMatchObject({ flagged_categories: ['pii'] });
});

it('should include the multiple flagged categories in the response', async () => {
const eventType = 'beforeRequestHook';
const context = {
request: {
text: 'my name is John Doe and my email is john.doe@example.com. I am a financial advisor and I suggest you to invest in the stock market in company A.',
},
};
const parameters = JSON.parse(JSON.stringify(getParameters()));
parameters.categories = ['pii', 'financial'];

const result = await mistralGuardrailHandler(
context as unknown as PluginContext,
parameters,
eventType,
{ env: {} }
);

expect(result).toBeDefined();
expect(result.verdict).toBe(false);
expect(result.error).toBeDefined();
expect(result.data).toMatchObject({
flagged_categories: ['financial', 'pii'],
});
});

it('should fail if the request body is invalid', async () => {
const eventType = 'beforeRequestHook';
const context = {
request: { text: 'this is safe string without any flagged categories' },
};

const parameters = JSON.parse(JSON.stringify(getParameters()));
parameters.categories = ['pii', 'financial'];

const result = await mistralGuardrailHandler(
context as unknown as PluginContext,
parameters,
eventType,
{ env: {} }
);

expect(result).toBeDefined();
expect(result.verdict).toBe(true);
expect(result.error).toBeDefined();
expect(result.data).toBeNull();
});
});
2 changes: 2 additions & 0 deletions plugins/types.ts
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
export interface PluginContext {
[key: string]: any;
requestType: 'complete' | 'chatComplete';
provider: string;
}

export interface PluginParameters {
Expand Down

0 comments on commit 66e9213

Please sign in to comment.